diff --git a/autoarray/__init__.py b/autoarray/__init__.py index f48c76ad3..c256417c6 100644 --- a/autoarray/__init__.py +++ b/autoarray/__init__.py @@ -9,6 +9,7 @@ from . import util from . import fixtures from . import mock as m +from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible from .dataset import preprocess from .dataset.abstract.dataset import AbstractDataset from .dataset.abstract.w_tilde import AbstractWTilde diff --git a/autoarray/dataset/interferometer/dataset.py b/autoarray/dataset/interferometer/dataset.py index 83cc1dc27..e1a8ee16e 100644 --- a/autoarray/dataset/interferometer/dataset.py +++ b/autoarray/dataset/interferometer/dataset.py @@ -164,6 +164,7 @@ def apply_w_tilde( batch_size: int = 128, show_progress: bool = False, show_memory: bool = False, + use_jax: bool = False, ): """ The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities @@ -192,7 +193,7 @@ def apply_w_tilde( if curvature_preload is None: - logger.info("INTERFEROMETER - Computing W-Tilde... 
May take a moment.") + logger.info("INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.") curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from( noise_map_real=self.noise_map.array.real, @@ -201,6 +202,7 @@ def apply_w_tilde( grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array, show_memory=show_memory, show_progress=show_progress, + use_jax=use_jax, ) dirty_image = self.transformer.image_from( diff --git a/autoarray/dataset/interferometer/w_tilde.py b/autoarray/dataset/interferometer/w_tilde.py index b9ce5857a..694c8a73f 100644 --- a/autoarray/dataset/interferometer/w_tilde.py +++ b/autoarray/dataset/interferometer/w_tilde.py @@ -1,9 +1,207 @@ +import json +import hashlib import numpy as np +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union from autoarray.dataset.abstract.w_tilde import AbstractWTilde from autoarray.mask.mask_2d import Mask2D +def _bbox_from_mask(mask_bool: np.ndarray) -> Tuple[int, int, int, int]: + """ + Return bbox (y_min, y_max, x_min, x_max) of the unmasked region. + mask_bool: True=masked, False=unmasked + """ + ys, xs = np.where(~mask_bool) + if ys.size == 0: + raise ValueError("Mask has no unmasked pixels; cannot compute bbox.") + return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max()) + + +def _mask_sha256(mask_bool: np.ndarray) -> str: + """ + Stable hash of the full boolean mask content (not just bbox). + """ + # Ensure contiguous, stable dtype + arr = np.ascontiguousarray(mask_bool.astype(np.uint8)) + return hashlib.sha256(arr.tobytes()).hexdigest() + + +def _as_pixel_scales_tuple(pixel_scales) -> Tuple[float, float]: + """ + Normalize pixel_scales to a stable 2-tuple of float. + Works with AutoArray pixel_scales objects or raw tuples. 
+ """ + try: + # autoarray typically stores pixel_scales as tuple-like + return (float(pixel_scales[0]), float(pixel_scales[1])) + except Exception: + # fallback: treat as scalar + s = float(pixel_scales) + return (s, s) + + +def _np_float_tuple(x) -> Tuple[float, float]: + return (float(x[0]), float(x[1])) + + +def curvature_preload_metadata_from(real_space_mask) -> Dict[str, Any]: + """ + Build the minimal metadata required to decide whether a stored curvature_preload + can be reused for the current WTildeInterferometer instance. + + The preload depends on: + - the *rectangular FFT grid extent* used for offset evaluation (bbox / extent) + - pixel scales (radians per pixel) + - (usually) the exact mask shape and content (recommended to hash) + + Returns + ------- + dict + JSON-serializable metadata. + """ + mask_bool = np.asarray(real_space_mask, dtype=bool) + y_min, y_max, x_min, x_max = _bbox_from_mask(mask_bool) + y_extent = y_max - y_min + 1 + x_extent = x_max - x_min + 1 + + pixel_scales = _as_pixel_scales_tuple(real_space_mask.pixel_scales) + + meta = { + "format": "autoarray.w_tilde.curvature_preload.v1", + "mask_shape": tuple(mask_bool.shape), + "pixel_scales": pixel_scales, + "bbox_unmasked": (y_min, y_max, x_min, x_max), + "rect_shape": (y_extent, x_extent), + # full-content hash: safest way to prevent accidental reuse + "mask_sha256": _mask_sha256(mask_bool), + } + return meta + + +def is_preload_metadata_compatible( + real_space_mask, + meta: Dict[str, Any], + *, + require_mask_hash: bool = True, + atol: float = 0.0, +) -> Tuple[bool, str]: + """ + Compare loaded metadata against current instance. + + Parameters + ---------- + meta + Metadata dict loaded from disk. + require_mask_hash + If True, require the full mask sha256 to match (safest). + If False, only check bbox + shape + pixel scales. 
+ atol + Tolerances for pixel scale comparisons (normally exact is fine + because these are configuration constants, but tolerances allow + for tiny float repr differences). + + Returns + ------- + (ok, reason) + ok: bool, True if compatible + reason: str, human-readable mismatch reason if not ok. + """ + current = curvature_preload_metadata_from(real_space_mask=real_space_mask) + + # 1) format version + if meta.get("format") != current["format"]: + return False, f"format mismatch: {meta.get('format')} != {current['format']}" + + # 2) mask shape + if tuple(meta.get("mask_shape", ())) != tuple(current["mask_shape"]): + return ( + False, + f"mask_shape mismatch: {meta.get('mask_shape')} != {current['mask_shape']}", + ) + + # 3) pixel scales + ps_saved = _np_float_tuple(meta.get("pixel_scales", (np.nan, np.nan))) + ps_curr = _np_float_tuple(current["pixel_scales"]) + + if not ( + np.isclose(ps_saved[0], ps_curr[0], atol=atol) + and np.isclose(ps_saved[1], ps_curr[1], atol=atol) + ): + return False, f"pixel_scales mismatch: {ps_saved} != {ps_curr}" + + # 4) bbox / rect shape + if tuple(meta.get("bbox_unmasked", ())) != tuple(current["bbox_unmasked"]): + return ( + False, + f"bbox_unmasked mismatch: {meta.get('bbox_unmasked')} != {current['bbox_unmasked']}", + ) + + if tuple(meta.get("rect_shape", ())) != tuple(current["rect_shape"]): + return ( + False, + f"rect_shape mismatch: {meta.get('rect_shape')} != {current['rect_shape']}", + ) + + # 5) full mask hash (optional but recommended) + if require_mask_hash: + if meta.get("mask_sha256") != current["mask_sha256"]: + return False, "mask_sha256 mismatch (mask content differs)" + + return True, "ok" + + +def load_curvature_preload_if_compatible( + file: Union[str, Path], + real_space_mask, + *, + require_mask_hash: bool = True, +) -> Optional[np.ndarray]: + """ + Load a saved curvature_preload if (and only if) it is compatible with the current mask geometry. 
+ + Parameters + ---------- + file + Path to a previously saved NPZ. + require_mask_hash + If True, require the full mask content hash to match (safest). + If False, only bbox + shape + pixel scales are checked. + + Returns + ------- + np.ndarray + The loaded curvature_preload if compatible, otherwise raises ValueError. + """ + file = Path(file) + if file.suffix.lower() != ".npz": + file = file.with_suffix(".npz") + + if not file.exists(): + raise FileNotFoundError(str(file)) + + with np.load(file, allow_pickle=False) as npz: + if "curvature_preload" not in npz or "meta_json" not in npz: + msg = f"File does not contain required fields: {file}" + raise ValueError(msg) + + meta_json = str(npz["meta_json"].item()) + meta = json.loads(meta_json) + + ok, reason = is_preload_metadata_compatible( + meta=meta, + real_space_mask=real_space_mask, + require_mask_hash=require_mask_hash, + atol=1.0e-8, + ) + + if not ok: + raise ValueError(f"curvature_preload incompatible: {reason}") + + return np.asarray(npz["curvature_preload"]) + + class WTildeInterferometer(AbstractWTilde): def __init__( self, @@ -122,3 +320,49 @@ def rect_index_for_mask_index(self) -> np.ndarray: ) return rect_indices + + def save_curvature_preload( + self, + file: Union[str, Path], + *, + overwrite: bool = False, + ) -> Path: + """ + Save curvature_preload plus enough metadata to ensure it is only reused when safe. + + Uses NPZ so we can store: + - curvature_preload (array) + - meta_json (string) + + Parameters + ---------- + file + Path to save to. Recommended suffix: ".npz". + If you pass ".npy", we will still save an ".npz" next to it. + overwrite + If False and the file exists, raise FileExistsError. + + Returns + ------- + Path + The path actually written (will end with ".npz"). 
+ """ + file = Path(file) + + # Force .npz (storing metadata safely) + if file.suffix.lower() != ".npz": + file = file.with_suffix(".npz") + + if file.exists() and not overwrite: + raise FileExistsError(f"File already exists: {file}") + + meta = curvature_preload_metadata_from(self.real_space_mask) + + meta_json = json.dumps(meta, sort_keys=True) + + np.savez_compressed( + file, + curvature_preload=np.asarray(self.curvature_preload), + meta_json=np.asarray(meta_json), + ) + return file diff --git a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py index b9fb29f76..0fa8925e1 100644 --- a/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py +++ b/autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py @@ -3,6 +3,7 @@ import numpy as np from tqdm import tqdm import os +import time logger = logging.getLogger(__name__) @@ -86,82 +87,130 @@ def w_tilde_curvature_preload_interferometer_from( chunk_k: int = 2048, show_progress: bool = False, show_memory: bool = False, + use_jax: bool = False, ) -> np.ndarray: """ - The matrix w_tilde is a matrix of dimensions [unmasked_image_pixels, unmasked_image_pixels] that encodes the - NUFFT of every pair of image pixels given the noise map. This can be used to efficiently compute the curvature - matrix via the mapping matrix, in a way that omits having to perform the NUFFT on every individual source pixel. - This provides a significant speed up for inversions of interferometer datasets with large number of visibilities. - - The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's, - making it impossible to store in memory and its use in linear algebra calculations extremely. 
This methods creates - a preload matrix that can compute the matrix w_tilde via an efficient preloading scheme which exploits the - symmetries in the NUFFT. - - To compute w_tilde, one first defines a real space mask where every False entry is an unmasked pixel which is - used in the calculation, for example: - - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI This is an imaging.Mask2D, where: - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI x = `True` (Pixel is masked and excluded from lens) - IxIxIxIoIoIoIxIxIxIxI o = `False` (Pixel is not masked and included in lens) - IxIxIxIoIoIoIxIxIxIxI - IxIxIxIoIoIoIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - - - Here, there are 9 unmasked pixels. Indexing of each unmasked pixel goes from the top-left corner right and - downwards, therefore: - - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxI0I1I2IxIxIxIxI - IxIxIxI3I4I5IxIxIxIxI - IxIxIxI6I7I8IxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - IxIxIxIxIxIxIxIxIxIxI - - In the standard calculation of `w_tilde` it is a matrix of - dimensions [unmasked_image_pixels, unmasked_pixel_images], therefore for the example mask above it would be - dimensions [9, 9]. One performs a double for loop over `unmasked_image_pixels`, using the (y,x) spatial offset - between every possible pair of unmasked image pixels to precompute values that depend on the properties of the NUFFT. - - This calculation has a lot of redundancy, because it uses the (y,x) *spatial offset* between the image pixels. For - example, if two image pixel are next to one another by the same spacing the same value will be computed via the - NUFFT. For the example mask above: - - - The value precomputed for pixel pair [0,1] is the same as pixel pairs [1,2], [3,4], [4,5], [6,7] and [7,9]. - - The value precomputed for pixel pair [0,3] is the same as pixel pairs [1,4], [2,5], [3,6], [4,7] and [5,8]. 
- - The values of pixels paired with themselves are also computed repeatedly for the standard calculation (e.g. 9 - times using the mask above). - - The `curvature_preload` method instead only computes each value once. To do this, it stores the preload values in a - matrix of dimensions [shape_masked_pixels_y, shape_masked_pixels_x, 2], where `shape_masked_pixels` is the (y,x) - size of the vertical and horizontal extent of unmasked pixels, e.g. the spatial extent over which the real space - grid extends. - - Each entry in the matrix `curvature_preload[:,:,0]` provides the the precomputed NUFFT value mapping an image pixel - to a pixel offset by that much in the y and x directions, for example: - - - curvature_preload[0,0,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and - in the x direction by 0 - the values of pixels paired with themselves. - - curvature_preload[1,0,0] gives the precomputed values of image pixels that are offset in the y direction by 1 and - in the x direction by 0 - the values of pixel pairs [0,3], [1,4], [2,5], [3,6], [4,7] and [5,8] - - curvature_preload[0,1,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and - in the x direction by 1 - the values of pixel pairs [0,1], [1,2], [3,4], [4,5], [6,7] and [7,9]. - - Flipped pairs: - - The above preloaded values pair all image pixel NUFFT values when a pixel is to the right and / or down of the - first image pixel. However, one must also precompute pairs where the paired pixel is to the left of the host - pixels. These pairings are stored in `curvature_preload[:,:,1]`, and the ordering of these pairings is flipped in the - x direction to make it straight forward to use this matrix when computing w_tilde. + Compute the interferometer W-tilde curvature preload on a rectangular offset grid, + exploiting translational symmetry of the NUFFT kernel. 
+ + This function computes a compact 2D preload array that depends only on the relative + (dy, dx) offsets between image pixels, avoiding construction of the dense + W-tilde matrix of shape [N_image_pixels, N_image_pixels]. + + The result can be used to rapidly assemble or apply W-tilde during curvature + matrix construction without performing a NUFFT per source pixel. + + ------------------------------------------------------------------------------- + Backend behaviour + ------------------------------------------------------------------------------- + - NumPy backend (use_jax=False, default): + * CPU execution + * Explicit Python loop over visibility chunks + * Supports progress bars and optional memory reporting + * Numerically closest to the original reference implementation + + - JAX backend (use_jax=True): + * JIT-compilable and GPU/TPU capable + * Uses fixed-size chunking and lax.fori_loop + * No Python-side loops during execution + * Progress bars and memory reporting are disabled + * Floating-point results are numerically stable but not guaranteed to be + bitwise-identical to NumPy due to parallel reduction order + + ------------------------------------------------------------------------------- + Numerical notes + ------------------------------------------------------------------------------- + The preload values are computed as: + + sum_k w_k * cos(dx * ku_k + dy * kv_k) + + where ku_k = 2π u_k and kv_k = 2π v_k. This corresponds to the real part of the + adjoint NUFFT evaluated on a uniform real-space offset grid. + + The chunking strategy controls temporary memory usage and GPU occupancy. Changing + `chunk_k` in JAX mode triggers recompilation. 
+ + ------------------------------------------------------------------------------- + Full Description (Original Documentation) + ------------------------------------------------------------------------------- + The matrix w_tilde is a matrix of dimensions [unmasked_image_pixels, unmasked_image_pixels] that encodes the + NUFFT of every pair of image pixels given the noise map. This can be used to efficiently compute the curvature + matrix via the mapping matrix, in a way that omits having to perform the NUFFT on every individual source pixel. + This provides a significant speed up for inversions of interferometer datasets with a large number of visibilities. + + The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many tens of GB, + making it impossible to store in memory and its use in linear algebra calculations extremely slow. This method creates + a preload matrix that can compute the matrix w_tilde via an efficient preloading scheme which exploits the + symmetries in the NUFFT. + + To compute w_tilde, one first defines a real space mask where every False entry is an unmasked pixel which is + used in the calculation, for example: + + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI This is an imaging.Mask2D, where: + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI x = `True` (Pixel is masked and excluded from lens) + IxIxIxIoIoIoIxIxIxIxI o = `False` (Pixel is not masked and included in lens) + IxIxIxIoIoIoIxIxIxIxI + IxIxIxIoIoIoIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + + Here, there are 9 unmasked pixels. 
Indexing of each unmasked pixel goes from the top-left corner right and + downwards, therefore: + + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxI0I1I2IxIxIxIxI + IxIxIxI3I4I5IxIxIxIxI + IxIxIxI6I7I8IxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + IxIxIxIxIxIxIxIxIxIxI + + In the standard calculation of `w_tilde` it is a matrix of + dimensions [unmasked_image_pixels, unmasked_image_pixels], therefore for the example mask above it would be + dimensions [9, 9]. One performs a double for loop over `unmasked_image_pixels`, using the (y,x) spatial offset + between every possible pair of unmasked image pixels to precompute values that depend on the properties of the NUFFT. + + This calculation has a lot of redundancy, because it uses the (y,x) *spatial offset* between the image pixels. For + example, if two pairs of image pixels are separated by the same spacing, the same value will be computed via the + NUFFT. For the example mask above: + + - The value precomputed for pixel pair [0,1] is the same as pixel pairs [1,2], [3,4], [4,5], [6,7] and [7,8]. + - The value precomputed for pixel pair [0,3] is the same as pixel pairs [1,4], [2,5], [3,6], [4,7] and [5,8]. + - The values of pixels paired with themselves are also computed repeatedly for the standard calculation (e.g. 9 + times using the mask above). + + The `curvature_preload` method instead only computes each value once. To do this, it stores the preload values in a + matrix of dimensions [shape_masked_pixels_y, shape_masked_pixels_x, 2], where `shape_masked_pixels` is the (y,x) + size of the vertical and horizontal extent of unmasked pixels, e.g. the spatial extent over which the real space + grid extends. 
+ + Each entry in the matrix `curvature_preload[:,:,0]` provides the precomputed NUFFT value mapping an image pixel + to a pixel offset by that much in the y and x directions, for example: + + - curvature_preload[0,0,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and + in the x direction by 0 - the values of pixels paired with themselves. + - curvature_preload[1,0,0] gives the precomputed values of image pixels that are offset in the y direction by 1 and + in the x direction by 0 - the values of pixel pairs [0,3], [1,4], [2,5], [3,6], [4,7] and [5,8]. + - curvature_preload[0,1,0] gives the precomputed values of image pixels that are offset in the y direction by 0 and + in the x direction by 1 - the values of pixel pairs [0,1], [1,2], [3,4], [4,5], [6,7] and [7,8]. + + Flipped pairs: + + The above preloaded values pair all image pixel NUFFT values when a pixel is to the right and / or down of the + first image pixel. However, one must also precompute pairs where the paired pixel is to the left of the host + pixel. These pairings are stored in `curvature_preload[:,:,1]`, and the ordering of these pairings is flipped in the + x direction to make it straightforward to use this matrix when computing w_tilde. + + Notes + ----- + - If use_jax=True, the JAX implementation is used (requires JAX installed). + - If use_jax=False, the NumPy implementation is used. Parameters ---------- @@ -176,15 +225,45 @@ def w_tilde_curvature_preload_interferometer_from( grid_radians_2d The 2D (y,x) grid of coordinates in radians corresponding to real-space mask within which the image that is Fourier transformed is computed. 
+ """ + if use_jax: + return w_tilde_curvature_preload_interferometer_via_jax_from( + noise_map_real=noise_map_real, + uv_wavelengths=uv_wavelengths, + shape_masked_pixels_2d=shape_masked_pixels_2d, + grid_radians_2d=grid_radians_2d, + chunk_k=chunk_k, + ) - Returns - ------- - ndarray - A matrix that precomputes the values for fast computation of w_tilde. + return w_tilde_curvature_preload_interferometer_via_np_from( + noise_map_real=noise_map_real, + uv_wavelengths=uv_wavelengths, + shape_masked_pixels_2d=shape_masked_pixels_2d, + grid_radians_2d=grid_radians_2d, + chunk_k=chunk_k, + show_progress=show_progress, + show_memory=show_memory, + ) + + +def w_tilde_curvature_preload_interferometer_via_np_from( + noise_map_real: np.ndarray, + uv_wavelengths: np.ndarray, + shape_masked_pixels_2d, + grid_radians_2d: np.ndarray, + *, + chunk_k: int = 2048, + show_progress: bool = False, + show_memory: bool = False, +) -> np.ndarray: """ - # ----------------------------- - # Enforce float64 everywhere - # ----------------------------- + NumPy/CPU implementation of the interferometer W-tilde curvature preload. + + See `w_tilde_curvature_preload_interferometer_from` for full description. 
+ """ + if chunk_k <= 0: + raise ValueError("chunk_k must be a positive integer") + noise_map_real = np.asarray(noise_map_real, dtype=np.float64) uv_wavelengths = np.asarray(uv_wavelengths, dtype=np.float64) grid_radians_2d = np.asarray(grid_radians_2d, dtype=np.float64) @@ -195,8 +274,9 @@ def w_tilde_curvature_preload_interferometer_from( gx = grid[..., 1] K = uv_wavelengths.shape[0] + n_chunks = (K + chunk_k - 1) // chunk_k - w = 1.0 / (noise_map_real**2) + w = 1.0 / (noise_map_real ** 2) ku = 2.0 * np.pi * uv_wavelengths[:, 0] kv = 2.0 * np.pi * uv_wavelengths[:, 1] @@ -208,65 +288,204 @@ def w_tilde_curvature_preload_interferometer_from( ym0, xm0 = gy[y_shape - 1, 0], gx[y_shape - 1, 0] ymm, xmm = gy[y_shape - 1, x_shape - 1], gx[y_shape - 1, x_shape - 1] - def accum_from_corner(y_ref, x_ref, gy_block, gx_block, label=""): + # ------------------------------------------------- + # Set up a single global progress bar + # ------------------------------------------------- + pbar = None + if show_progress: + + from tqdm import tqdm # type: ignore + + n_quadrants = 1 + if x_shape > 1: + n_quadrants += 1 + if y_shape > 1: + n_quadrants += 1 + if (y_shape > 1) and (x_shape > 1): + n_quadrants += 1 + + pbar = tqdm( + total=n_chunks * n_quadrants, + desc="Accumulating visibilities (W-tilde preload)", + ) + + def accum_from_corner_np(y_ref, x_ref, gy_block, gx_block): dy = y_ref - gy_block dx = x_ref - gx_block acc = np.zeros(gy_block.shape, dtype=np.float64) - iterator = range(0, K, chunk_k) - if show_progress: - iterator = tqdm( - iterator, - desc=f"Accumulating visibilities {label}", - total=(K + chunk_k - 1) // chunk_k, - ) - - for k0 in iterator: + for k0 in range(0, K, chunk_k): k1 = min(K, k0 + chunk_k) phase = dx[..., None] * ku[k0:k1] + dy[..., None] * kv[k0:k1] - acc += np.sum( - np.cos(phase) * w[k0:k1], - axis=2, - ) + acc += np.sum(np.cos(phase) * w[k0:k1], axis=2) + + if pbar is not None: + pbar.update(1) - if show_memory and show_progress: - 
_report_memory(acc) + if show_memory and show_progress and "_report_memory" in globals(): + globals()["_report_memory"](acc) return acc # ----------------------------- # Main quadrant (+,+) # ----------------------------- - out[:y_shape, :x_shape] = accum_from_corner(y00, x00, gy, gx, label="(+,+)") + out[:y_shape, :x_shape] = accum_from_corner_np(y00, x00, gy, gx) # ----------------------------- # Flip in x (+,-) # ----------------------------- if x_shape > 1: - block = accum_from_corner(y0m, x0m, gy[:, ::-1], gx[:, ::-1], label="(+,-)") + block = accum_from_corner_np(y0m, x0m, gy[:, ::-1], gx[:, ::-1]) out[:y_shape, -1:-(x_shape):-1] = block[:, 1:] # ----------------------------- # Flip in y (-,+) # ----------------------------- if y_shape > 1: - block = accum_from_corner(ym0, xm0, gy[::-1, :], gx[::-1, :], label="(-,+)") + block = accum_from_corner_np(ym0, xm0, gy[::-1, :], gx[::-1, :]) out[-1:-(y_shape):-1, :x_shape] = block[1:, :] # ----------------------------- # Flip in x and y (-,-) # ----------------------------- if (y_shape > 1) and (x_shape > 1): - block = accum_from_corner( - ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1], label="(-,-)" - ) + block = accum_from_corner_np(ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1]) out[-1:-(y_shape):-1, -1:-(x_shape):-1] = block[1:, 1:] + if pbar is not None: + pbar.close() + return out +def w_tilde_curvature_preload_interferometer_via_jax_from( + noise_map_real: np.ndarray, + uv_wavelengths: np.ndarray, + shape_masked_pixels_2d, + grid_radians_2d: np.ndarray, + *, + chunk_k: int = 2048, +) -> np.ndarray: + """ + JAX implementation of the interferometer W-tilde curvature preload. + + This version is intended for performance (CPU/GPU/TPU) and therefore: + - uses JIT compilation internally + - uses a compiled for-loop (lax.fori_loop) over fixed-size visibility chunks + - does not support progress bars or memory reporting (those require Python loops) + + See `w_tilde_curvature_preload_interferometer_from` for full description. 
+ """ + import jax + import jax.numpy as jnp + + if chunk_k <= 0: + raise ValueError("chunk_k must be a positive integer") + + y_shape, x_shape = shape_masked_pixels_2d + + # Device arrays; keep float64 to match NumPy path as closely as possible. + noise_map_real_x = jnp.asarray(noise_map_real, dtype=jnp.float64) + uv_wavelengths_x = jnp.asarray(uv_wavelengths, dtype=jnp.float64) + grid_radians_2d_x = jnp.asarray(grid_radians_2d, dtype=jnp.float64) + + # Precompute weights and angular frequencies on device + w_x = 1.0 / (noise_map_real_x**2) + ku_x = 2.0 * jnp.pi * uv_wavelengths_x[:, 0] + kv_x = 2.0 * jnp.pi * uv_wavelengths_x[:, 1] + + grid = grid_radians_2d_x[:y_shape, :x_shape] + gy = grid[..., 0] + gx = grid[..., 1] + + # ----------------------------- + # IMPORTANT: pad so dynamic_slice(chunk_k) is always legal + # ----------------------------- + K = int(uv_wavelengths_x.shape[0]) # known at trace/compile time + n_chunks = (K + chunk_k - 1) // chunk_k + K_pad = n_chunks * chunk_k + pad_len = K_pad - K + + if pad_len > 0: + ku_x = jnp.pad(ku_x, (0, pad_len)) + kv_x = jnp.pad(kv_x, (0, pad_len)) + w_x = jnp.pad(w_x, (0, pad_len)) + + # A fixed [chunk_k] index vector used to mask the padded tail (last chunk). + idx = jnp.arange(chunk_k) + + def _compute_all_quadrants(gy, gx, *, chunk_k: int): + # Corner coordinates + y00, x00 = gy[0, 0], gx[0, 0] + y0m, x0m = gy[0, x_shape - 1], gx[0, x_shape - 1] + ym0, xm0 = gy[y_shape - 1, 0], gx[y_shape - 1, 0] + ymm, xmm = gy[y_shape - 1, x_shape - 1], gx[y_shape - 1, x_shape - 1] + + def accum_from_corner_jax(y_ref, x_ref, gy_block, gx_block): + dy = y_ref - gy_block + dx = x_ref - gx_block + + acc = jnp.zeros(gy_block.shape, dtype=jnp.float64) + + def body(i, acc_): + k0 = i * chunk_k + + # Always legal because ku_x/kv_x/w_x were padded to length K_pad. 
+ ku_s = jax.lax.dynamic_slice(ku_x, (k0,), (chunk_k,)) + kv_s = jax.lax.dynamic_slice(kv_x, (k0,), (chunk_k,)) + w_s = jax.lax.dynamic_slice(w_x, (k0,), (chunk_k,)) + + # Mask the padded tail (only the first K entries are real). + valid = (idx + k0) < K + w_s = jnp.where(valid, w_s, 0.0) + + phase = ( + dx[..., None] * ku_s[None, None, :] + + dy[..., None] * kv_s[None, None, :] + ) + return acc_ + jnp.sum(jnp.cos(phase) * w_s[None, None, :], axis=2) + + return jax.lax.fori_loop(0, n_chunks, body, acc) + + out = jnp.zeros((2 * y_shape, 2 * x_shape), dtype=jnp.float64) + + # (+,+) + out = out.at[:y_shape, :x_shape].set(accum_from_corner_jax(y00, x00, gy, gx)) + + # (+,-) x-flip + if x_shape > 1: + block = accum_from_corner_jax(y0m, x0m, gy[:, ::-1], gx[:, ::-1]) + out = out.at[:y_shape, -1:-(x_shape):-1].set(block[:, 1:]) + + # (-,+) y-flip + if y_shape > 1: + block = accum_from_corner_jax(ym0, xm0, gy[::-1, :], gx[::-1, :]) + out = out.at[-1:-(y_shape):-1, :x_shape].set(block[1:, :]) + + # (-,-) x- and y-flip + if (y_shape > 1) and (x_shape > 1): + block = accum_from_corner_jax(ymm, xmm, gy[::-1, ::-1], gx[::-1, ::-1]) + out = out.at[-1:-(y_shape):-1, -1:-(x_shape):-1].set(block[1:, 1:]) + + return out + + _compute_all_quadrants_jit = jax.jit( + _compute_all_quadrants, static_argnames=("chunk_k",) + ) + + t0 = time.time() + out = _compute_all_quadrants_jit(gy, gx, chunk_k=chunk_k) + out.block_until_ready() # ensure timing includes actual device execution + t1 = time.time() + + logger.info("INTERFEROMETER - Finished W-Tilde (JAX) in %.3f seconds", (t1 - t0)) + + return np.asarray(out) + + def w_tilde_via_preload_from(curvature_preload, native_index_for_slim_index): """ Use the preloaded w_tilde matrix (see `curvature_preload_interferometer_from`) to compute diff --git a/test_autoarray/dataset/interferometer/test_dataset.py b/test_autoarray/dataset/interferometer/test_dataset.py index 28f296c3f..b53ff3e67 100644 --- 
a/test_autoarray/dataset/interferometer/test_dataset.py +++ b/test_autoarray/dataset/interferometer/test_dataset.py @@ -4,6 +4,7 @@ import shutil import autoarray as aa +import pytest from autoarray.operators import transformer @@ -148,3 +149,55 @@ def test__different_interferometer_without_mock_objects__customize_constructor_i assert (dataset.data == 1.0 + 1.0j * np.ones((19,))).all() assert (dataset.noise_map == 2.0 + 2.0j * np.ones((19,))).all() assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all() + + +def test__curvature_preload_metadata_from( + visibilities_7, + visibilities_noise_map_7, + uv_wavelengths_7x2, + mask_2d_7x7, +): + + dataset = aa.Interferometer( + data=visibilities_7, + noise_map=visibilities_noise_map_7, + uv_wavelengths=uv_wavelengths_7x2, + real_space_mask=mask_2d_7x7, + ) + + dataset = dataset.apply_w_tilde(use_jax=False) + + file = f"{test_data_path}/curvature_preload_metadata" + + dataset.w_tilde.save_curvature_preload( + file=file, + overwrite=True, + ) + + curvature_preload = aa.load_curvature_preload_if_compatible( + file=file, real_space_mask=dataset.real_space_mask + ) + + assert curvature_preload[0,0] == pytest.approx(1.75, 1.0e-4) + + real_space_mask_changed = np.array( + [ + [True, True, True, True, True, True, True], + [True, True, True, True, True, True, True], + [True, True, False, False, False, True, True], + [True, True, False, True, False, True, True], + [True, True, False, False, False, True, True], + [True, True, True, True, True, True, True], + [True, True, True, True, True, True, True], + ] + ) + + real_space_mask_changed = aa.Mask2D( + mask=real_space_mask_changed, pixel_scales=(1.0, 1.0) + ) + + with pytest.raises(ValueError): + + curvature_preload = aa.load_curvature_preload_if_compatible( + file=file, real_space_mask=real_space_mask_changed + )