Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autoarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import util
from . import fixtures
from . import mock as m
from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible
from .dataset import preprocess
from .dataset.abstract.dataset import AbstractDataset
from .dataset.abstract.w_tilde import AbstractWTilde
Expand Down
4 changes: 3 additions & 1 deletion autoarray/dataset/interferometer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def apply_w_tilde(
batch_size: int = 128,
show_progress: bool = False,
show_memory: bool = False,
use_jax: bool = False,
):
"""
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities
Expand Down Expand Up @@ -192,7 +193,7 @@ def apply_w_tilde(

if curvature_preload is None:

logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
logger.info("INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.")

curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
noise_map_real=self.noise_map.array.real,
Expand All @@ -201,6 +202,7 @@ def apply_w_tilde(
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
show_memory=show_memory,
show_progress=show_progress,
use_jax=use_jax,
)

dirty_image = self.transformer.image_from(
Expand Down
244 changes: 244 additions & 0 deletions autoarray/dataset/interferometer/w_tilde.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,207 @@
import json
import hashlib
import numpy as np
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

from autoarray.dataset.abstract.w_tilde import AbstractWTilde
from autoarray.mask.mask_2d import Mask2D


def _bbox_from_mask(mask_bool: np.ndarray) -> Tuple[int, int, int, int]:
"""
Return bbox (y_min, y_max, x_min, x_max) of the unmasked region.
mask_bool: True=masked, False=unmasked
"""
ys, xs = np.where(~mask_bool)
if ys.size == 0:
raise ValueError("Mask has no unmasked pixels; cannot compute bbox.")
return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())


def _mask_sha256(mask_bool: np.ndarray) -> str:
"""
Stable hash of the full boolean mask content (not just bbox).
"""
# Ensure contiguous, stable dtype
arr = np.ascontiguousarray(mask_bool.astype(np.uint8))
return hashlib.sha256(arr.tobytes()).hexdigest()


def _as_pixel_scales_tuple(pixel_scales) -> Tuple[float, float]:
"""
Normalize pixel_scales to a stable 2-tuple of float.
Works with AutoArray pixel_scales objects or raw tuples.
"""
try:
# autoarray typically stores pixel_scales as tuple-like
return (float(pixel_scales[0]), float(pixel_scales[1]))
except Exception:
# fallback: treat as scalar
s = float(pixel_scales)
return (s, s)


def _np_float_tuple(x) -> Tuple[float, float]:
return (float(x[0]), float(x[1]))


def curvature_preload_metadata_from(real_space_mask) -> Dict[str, Any]:
"""
Build the minimal metadata required to decide whether a stored curvature_preload
can be reused for the current WTildeInterferometer instance.

The preload depends on:
- the *rectangular FFT grid extent* used for offset evaluation (bbox / extent)
- pixel scales (radians per pixel)
- (usually) the exact mask shape and content (recommended to hash)

Returns
-------
dict
JSON-serializable metadata.
"""
mask_bool = np.asarray(real_space_mask, dtype=bool)
y_min, y_max, x_min, x_max = _bbox_from_mask(mask_bool)
y_extent = y_max - y_min + 1
x_extent = x_max - x_min + 1

pixel_scales = _as_pixel_scales_tuple(real_space_mask.pixel_scales)

meta = {
"format": "autoarray.w_tilde.curvature_preload.v1",
"mask_shape": tuple(mask_bool.shape),
"pixel_scales": pixel_scales,
"bbox_unmasked": (y_min, y_max, x_min, x_max),
"rect_shape": (y_extent, x_extent),
# full-content hash: safest way to prevent accidental reuse
"mask_sha256": _mask_sha256(mask_bool),
}
return meta


def is_preload_metadata_compatible(
real_space_mask,
meta: Dict[str, Any],
*,
require_mask_hash: bool = True,
atol: float = 0.0,
) -> Tuple[bool, str]:
"""
Compare loaded metadata against current instance.

Parameters
----------
meta
Metadata dict loaded from disk.
require_mask_hash
If True, require the full mask sha256 to match (safest).
If False, only check bbox + shape + pixel scales.
atol
Tolerances for pixel scale comparisons (normally exact is fine
because these are configuration constants, but tolerances allow
for tiny float repr differences).

Returns
-------
(ok, reason)
ok: bool, True if compatible
reason: str, human-readable mismatch reason if not ok.
"""
current = curvature_preload_metadata_from(real_space_mask=real_space_mask)

# 1) format version
if meta.get("format") != current["format"]:
return False, f"format mismatch: {meta.get('format')} != {current['format']}"

# 2) mask shape
if tuple(meta.get("mask_shape", ())) != tuple(current["mask_shape"]):
return (
False,
f"mask_shape mismatch: {meta.get('mask_shape')} != {current['mask_shape']}",
)

# 3) pixel scales
ps_saved = _np_float_tuple(meta.get("pixel_scales", (np.nan, np.nan)))
ps_curr = _np_float_tuple(current["pixel_scales"])

if not (
np.isclose(ps_saved[0], ps_curr[0], atol=atol)
and np.isclose(ps_saved[1], ps_curr[1], atol=atol)
):
return False, f"pixel_scales mismatch: {ps_saved} != {ps_curr}"

# 4) bbox / rect shape
if tuple(meta.get("bbox_unmasked", ())) != tuple(current["bbox_unmasked"]):
return (
False,
f"bbox_unmasked mismatch: {meta.get('bbox_unmasked')} != {current['bbox_unmasked']}",
)

if tuple(meta.get("rect_shape", ())) != tuple(current["rect_shape"]):
return (
False,
f"rect_shape mismatch: {meta.get('rect_shape')} != {current['rect_shape']}",
)

# 5) full mask hash (optional but recommended)
if require_mask_hash:
if meta.get("mask_sha256") != current["mask_sha256"]:
return False, "mask_sha256 mismatch (mask content differs)"

return True, "ok"


def load_curvature_preload_if_compatible(
file: Union[str, Path],
real_space_mask,
*,
require_mask_hash: bool = True,
) -> Optional[np.ndarray]:
"""
Load a saved curvature_preload if (and only if) it is compatible with the current mask geometry.

Parameters
----------
file
Path to a previously saved NPZ.
require_mask_hash
If True, require the full mask content hash to match (safest).
If False, only bbox + shape + pixel scales are checked.

Returns
-------
np.ndarray
The loaded curvature_preload if compatible, otherwise raises ValueError.
"""
file = Path(file)
if file.suffix.lower() != ".npz":
file = file.with_suffix(".npz")

if not file.exists():
raise FileNotFoundError(str(file))

with np.load(file, allow_pickle=False) as npz:
if "curvature_preload" not in npz or "meta_json" not in npz:
msg = f"File does not contain required fields: {file}"
raise ValueError(msg)

meta_json = str(npz["meta_json"].item())
meta = json.loads(meta_json)

ok, reason = is_preload_metadata_compatible(
meta=meta,
real_space_mask=real_space_mask,
require_mask_hash=require_mask_hash,
atol=1.0e-8,
)

if not ok:
raise ValueError(f"curvature_preload incompatible: {reason}")

return np.asarray(npz["curvature_preload"])


class WTildeInterferometer(AbstractWTilde):
def __init__(
self,
Expand Down Expand Up @@ -122,3 +320,49 @@ def rect_index_for_mask_index(self) -> np.ndarray:
)

return rect_indices

def save_curvature_preload(
self,
file: Union[str, Path],
*,
overwrite: bool = False,
) -> Path:
"""
Save curvature_preload plus enough metadata to ensure it is only reused when safe.

Uses NPZ so we can store:
- curvature_preload (array)
- meta_json (string)

Parameters
----------
file
Path to save to. Recommended suffix: ".npz".
If you pass ".npy", we will still save an ".npz" next to it.
overwrite
If False and the file exists, raise FileExistsError.

Returns
-------
Path
The path actually written (will end with ".npz").
"""
file = Path(file)

# Force .npz (storing metadata safely)
if file.suffix.lower() != ".npz":
file = file.with_suffix(".npz")

if file.exists() and not overwrite:
raise FileExistsError(f"File already exists: {file}")

meta = curvature_preload_metadata_from(self.real_space_mask)

meta_json = json.dumps(meta, sort_keys=True)

np.savez_compressed(
file,
curvature_preload=np.asarray(self.curvature_preload),
meta_json=np.asarray(meta_json),
)
return file
Loading
Loading