Skip to content

Feature/jax w tilde preload#202

Merged
Jammy2211 merged 7 commits intomainfrom
feature/jax_w_tilde_preload
Jan 17, 2026
Merged

Feature/jax w tilde preload#202
Jammy2211 merged 7 commits intomainfrom
feature/jax_w_tilde_preload

Conversation

@Jammy2211
Copy link
Owner

The fast interferometry JAX GPU implementation (#201) uses a preload of the NUFFT in order to compute the curvature_matrix.

This curvature_preload matrix is computed once, before lens modeling, and reused throughout lens modeling. For high number of visibilities and resolution real space mask, this calculation can take minutes or hours on a CPU.

The previous PR did not convert this calculation to JAX or run on GPU.

This pull request makes the W-Tilde curvature preload computation support JAX and GPU, with profiling suggeting at least x100 speed up for high resolution datasets, with the calculation taking under a minute for the highest resolution / visibilities ALMA datasets tested.

It also includes utilities for safely saving and loading precomputed data, which check metadata to ensure the loaded data matches the data being analysed.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request adds JAX/GPU support for the W-Tilde curvature preload computation in interferometer analysis, enabling significant performance improvements (100x+ speedup) for high-resolution datasets. The PR introduces utilities for safely saving and loading precomputed curvature data with metadata validation.

Changes:

  • Adds JAX implementation of W-Tilde curvature preload computation alongside existing NumPy implementation
  • Introduces metadata-based save/load utilities to ensure precomputed data is only reused when compatible with current analysis parameters
  • Updates API to expose use_jax parameter for enabling GPU acceleration

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
autoarray/__init__.py Exports new load_curvature_preload_if_compatible utility function
autoarray/dataset/interferometer/dataset.py Adds use_jax parameter to apply_w_tilde method and updates log message
autoarray/dataset/interferometer/w_tilde.py Implements metadata utilities and save/load methods for curvature preload
autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py Refactors main function into NumPy and JAX implementations with improved documentation
test_autoarray/dataset/interferometer/test_dataset.py Adds test for metadata validation when loading precomputed data

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@@ -1,9 +1,208 @@
import json
import hashlib
from dataclasses import dataclass
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dataclass import is unused in this file. It should be removed to keep the imports clean.

Suggested change
from dataclasses import dataclass

Copilot uses AI. Check for mistakes.
Comment on lines +101 to +102
rtol, atol
Tolerances for pixel scale comparisons (normally exact is fine
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring mentions rtol parameter but the function signature only has atol. Either remove rtol from the docstring or add it to the function signature.

Suggested change
rtol, atol
Tolerances for pixel scale comparisons (normally exact is fine
atol
Tolerance for pixel scale comparisons (normally exact is fine

Copilot uses AI. Check for mistakes.
Comment on lines +175 to +176
np.ndarray or None
The loaded curvature_preload if compatible, otherwise None (unless raise_on_mismatch=True).
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type documentation is incorrect. The function always raises ValueError if incompatible (line 201), it never returns None. The docstring should state: "The loaded curvature_preload if compatible, otherwise raises ValueError." Also, raise_on_mismatch parameter is mentioned but doesn't exist.

Suggested change
np.ndarray or None
The loaded curvature_preload if compatible, otherwise None (unless raise_on_mismatch=True).
np.ndarray
The loaded curvature_preload if compatible, otherwise raises ValueError.

Copilot uses AI. Check for mistakes.
if curvature_preload is None:

logger.info("INTERFEROMETER - Computing W-Tilde... May take a moment.")
logger.info("INTERFEROMETER Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.")
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The character between 'INTERFEROMETER' and 'Computing' appears to be an en-dash (U+2013) rather than a hyphen-minus. For consistency with the rest of the codebase (e.g., line 489 in inversion_interferometer_util.py uses regular hyphen), this should use a standard hyphen-minus character.

Suggested change
logger.info("INTERFEROMETER Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.")
logger.info("INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, extreme inputs may exceed hours.")

Copilot uses AI. Check for mistakes.
Comment on lines +424 to +425

def _compute_all_quadrants(gy, gx, *, chunk_k: int):
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nested function _compute_all_quadrants captures multiple variables from the outer scope (K, n_chunks, idx, ku_x, kv_x, w_x, y_shape, x_shape) which can make the code harder to reason about. Consider passing these as explicit parameters or restructuring the code to make dependencies clearer.

Suggested change
def _compute_all_quadrants(gy, gx, *, chunk_k: int):
quadrant_context = {
"K": K,
"n_chunks": n_chunks,
"idx": idx,
"ku_x": ku_x,
"kv_x": kv_x,
"w_x": w_x,
"y_shape": y_shape,
"x_shape": x_shape,
}
def _compute_all_quadrants(gy, gx, *, chunk_k: int):
K = quadrant_context["K"]
n_chunks = quadrant_context["n_chunks"]
idx = quadrant_context["idx"]
ku_x = quadrant_context["ku_x"]
kv_x = quadrant_context["kv_x"]
w_x = quadrant_context["w_x"]
y_shape = quadrant_context["y_shape"]
x_shape = quadrant_context["x_shape"]

Copilot uses AI. Check for mistakes.
assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all()


def test__curvature_preload_metadata_from(
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test only validates the error case (incompatible mask) but doesn't verify that the loaded preload data is correct when it is compatible. Consider adding an assertion to check that the loaded curvature_preload matches the original saved data.

Copilot uses AI. Check for mistakes.

with pytest.raises(ValueError):

curvature_preload = aa.load_curvature_preload_if_compatible(
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable curvature_preload is not used.

Suggested change
curvature_preload = aa.load_curvature_preload_if_compatible(
aa.load_curvature_preload_if_compatible(

Copilot uses AI. Check for mistakes.
overwrite=True,
)

curvature_preload = aa.load_curvature_preload_if_compatible(
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assignment to 'curvature_preload' is unnecessary as it is redefined before this value is used.

Suggested change
curvature_preload = aa.load_curvature_preload_if_compatible(
aa.load_curvature_preload_if_compatible(

Copilot uses AI. Check for mistakes.
if show_memory and show_progress and "_report_memory" in globals():
try:
globals()["_report_memory"](acc)
except Exception:
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'except' clause does nothing but pass and there is no explanatory comment.

Suggested change
except Exception:
except Exception:
# _report_memory is a best-effort diagnostic; ignore any errors it raises.

Copilot uses AI. Check for mistakes.
@Jammy2211 Jammy2211 merged commit e3ce0f7 into main Jan 17, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants