Conversation
…e mask meta data checks
There was a problem hiding this comment.
Pull request overview
This pull request adds JAX/GPU support for the W-Tilde curvature preload computation in interferometer analysis, enabling significant performance improvements (100x+ speedup) for high-resolution datasets. The PR introduces utilities for safely saving and loading precomputed curvature data with metadata validation.
Changes:
- Adds JAX implementation of W-Tilde curvature preload computation alongside existing NumPy implementation
- Introduces metadata-based save/load utilities to ensure precomputed data is only reused when compatible with current analysis parameters
- Updates API to expose
use_jaxparameter for enabling GPU acceleration
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
autoarray/__init__.py |
Exports new load_curvature_preload_if_compatible utility function |
autoarray/dataset/interferometer/dataset.py |
Adds use_jax parameter to apply_w_tilde method and updates log message |
autoarray/dataset/interferometer/w_tilde.py |
Implements metadata utilities and save/load methods for curvature preload |
autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py |
Refactors main function into NumPy and JAX implementations with improved documentation |
test_autoarray/dataset/interferometer/test_dataset.py |
Adds test for metadata validation when loading precomputed data |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -1,9 +1,208 @@ | |||
| import json | |||
| import hashlib | |||
| from dataclasses import dataclass | |||
There was a problem hiding this comment.
The dataclass import is unused in this file. It should be removed to keep the imports clean.
| from dataclasses import dataclass |
| rtol, atol | ||
| Tolerances for pixel scale comparisons (normally exact is fine |
There was a problem hiding this comment.
The docstring mentions rtol parameter but the function signature only has atol. Either remove rtol from the docstring or add it to the function signature.
| rtol, atol | |
| Tolerances for pixel scale comparisons (normally exact is fine | |
| atol | |
| Tolerance for pixel scale comparisons (normally exact is fine |
| np.ndarray or None | ||
| The loaded curvature_preload if compatible, otherwise None (unless raise_on_mismatch=True). |
There was a problem hiding this comment.
The return type documentation is incorrect. The function always raises ValueError if incompatible (line 201), it never returns None. The docstring should state: "The loaded curvature_preload if compatible, otherwise raises ValueError." Also, raise_on_mismatch parameter is mentioned but doesn't exist.
| np.ndarray or None | |
| The loaded curvature_preload if compatible, otherwise None (unless raise_on_mismatch=True). | |
| np.ndarray | |
| The loaded curvature_preload if compatible, otherwise raises ValueError. |
| if curvature_preload is None: | ||
|
|
||
| logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.") | ||
| logger.info("INTERFEROMETER – Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.") |
There was a problem hiding this comment.
The character between 'INTERFEROMETER' and 'Computing' appears to be an en-dash (U+2013) rather than a hyphen-minus. For consistency with the rest of the codebase (e.g., line 489 in inversion_interferometer_util.py uses regular hyphen), this should use a standard hyphen-minus character.
| logger.info("INTERFEROMETER – Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.") | |
| logger.info("INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.") |
|
|
||
| def _compute_all_quadrants(gy, gx, *, chunk_k: int): |
There was a problem hiding this comment.
The nested function _compute_all_quadrants captures multiple variables from the outer scope (K, n_chunks, idx, ku_x, kv_x, w_x, y_shape, x_shape) which can make the code harder to reason about. Consider passing these as explicit parameters or restructuring the code to make dependencies clearer.
| def _compute_all_quadrants(gy, gx, *, chunk_k: int): | |
| quadrant_context = { | |
| "K": K, | |
| "n_chunks": n_chunks, | |
| "idx": idx, | |
| "ku_x": ku_x, | |
| "kv_x": kv_x, | |
| "w_x": w_x, | |
| "y_shape": y_shape, | |
| "x_shape": x_shape, | |
| } | |
| def _compute_all_quadrants(gy, gx, *, chunk_k: int): | |
| K = quadrant_context["K"] | |
| n_chunks = quadrant_context["n_chunks"] | |
| idx = quadrant_context["idx"] | |
| ku_x = quadrant_context["ku_x"] | |
| kv_x = quadrant_context["kv_x"] | |
| w_x = quadrant_context["w_x"] | |
| y_shape = quadrant_context["y_shape"] | |
| x_shape = quadrant_context["x_shape"] |
| assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all() | ||
|
|
||
|
|
||
| def test__curvature_preload_metadata_from( |
There was a problem hiding this comment.
The test only validates the error case (incompatible mask) but doesn't verify that the loaded preload data is correct when it is compatible. Consider adding an assertion to check that the loaded curvature_preload matches the original saved data.
|
|
||
| with pytest.raises(ValueError): | ||
|
|
||
| curvature_preload = aa.load_curvature_preload_if_compatible( |
There was a problem hiding this comment.
Variable curvature_preload is not used.
| curvature_preload = aa.load_curvature_preload_if_compatible( | |
| aa.load_curvature_preload_if_compatible( |
| overwrite=True, | ||
| ) | ||
|
|
||
| curvature_preload = aa.load_curvature_preload_if_compatible( |
| if show_memory and show_progress and "_report_memory" in globals(): | ||
| try: | ||
| globals()["_report_memory"](acc) | ||
| except Exception: |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| except Exception: | |
| except Exception: | |
| # _report_memory is a best-effort diagnostic; ignore any errors it raises. |
The fast interferometry JAX GPU implementation (#201) uses a preload of the NUFFT in order to compute the
curvature_matrix.This
curvature_preloadmatrix is computed once, before lens modeling, and reused throughout lens modeling. For high number of visibilities and resolution real space mask, this calculation can take minutes or hours on a CPU.The previous PR did not convert this calculation to JAX or run on GPU.
This pull request makes the W-Tilde curvature preload computation support JAX and GPU, with profiling suggeting at least x100 speed up for high resolution datasets, with the calculation taking under a minute for the highest resolution / visibilities ALMA datasets tested.
It also includes utilities for safely saving and loading precomputed data, which check metadata to ensure the loaded data matches the data being analysed.