Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from astropy.io import fits
import logging
import numpy as np
from pathlib import Path
from typing import Optional

from autoconf import cached_property
Expand All @@ -14,6 +16,8 @@

from autoarray.structures.arrays import array_2d_util

from autoarray.inversion.inversion.interferometer import inversion_interferometer_util

logger = logging.getLogger(__name__)


Expand All @@ -25,6 +29,7 @@ def __init__(
uv_wavelengths: np.ndarray,
real_space_mask,
transformer_class=TransformerNUFFT,
preprocessing_directory=None,
):
"""
An interferometer dataset, containing the visibilities data, noise-map, real-space msk, Fourier transformer and
Expand Down Expand Up @@ -86,6 +91,8 @@ def __init__(
uv_wavelengths=uv_wavelengths, real_space_mask=real_space_mask
)

self.preprocessing_directory = Path(preprocessing_directory) if preprocessing_directory is not None else None

@cached_property
def grids(self):
return GridsDataset(
Expand Down Expand Up @@ -132,6 +139,27 @@ def from_fits(
transformer_class=transformer_class,
)

def w_tilde_preprocessing(self):

if self.preprocessing_directory.is_dir():

filename = "{}/curvature_preload.fits".format(self.preprocessing_directory)

if not self.preprocessing_directory.isfile(filename):
print(
"The file {} does not exist".format(filename)
)
logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")

curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
noise_map_real=self.noise_map.real,
uv_wavelengths=self.uv_wavelengths,
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
grid_radians_2d=self.transformer.grid.mask.unmasked_grid_sub_1.in_radians.native,
)

fits.writeto(filename, data=curvature_preload)

@cached_property
def w_tilde(self):
"""
Expand All @@ -152,10 +180,8 @@ def w_tilde(self):

logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")

from autoarray.inversion.inversion import inversion_util_secret

curvature_preload = (
inversion_util_secret.w_tilde_curvature_preload_interferometer_from(
inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
noise_map_real=np.array(self.noise_map.real),
uv_wavelengths=np.array(self.uv_wavelengths),
shape_masked_pixels_2d=np.array(
Expand All @@ -167,7 +193,7 @@ def w_tilde(self):
)
)

w_matrix = inversion_util_secret.w_tilde_via_preload_from(
w_matrix = inversion_interferometer_util.w_tilde_via_preload_from(
w_tilde_preload=curvature_preload,
native_index_for_slim_index=self.real_space_mask.derive_indexes.native_for_slim,
)
Expand Down
5 changes: 0 additions & 5 deletions autoarray/inversion/inversion/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,6 @@ def inversion_interferometer_from(
-------
An `Inversion` whose type is determined by the input `dataset` and `settings`.
"""
try:
from autoarray.inversion.inversion import inversion_util_secret
except ImportError:
settings.use_w_tilde = False

if any(
isinstance(linear_obj, AbstractLinearObjFuncList)
for linear_obj in linear_obj_list
Expand Down
Loading
Loading