Skip to content

Feature/jax w tilde preload#264

Merged
Jammy2211 merged 4 commits intomainfrom
feature/jax_w_tilde_preload
Jan 17, 2026
Merged

Feature/jax w tilde preload#264
Jammy2211 merged 4 commits intomainfrom
feature/jax_w_tilde_preload

Conversation

@Jammy2211
Copy link
Owner

The fast interferometry JAX GPU implementation (Jammy2211/PyAutoArray#201) uses a preload of the NUFFT in order to compute the curvature_matrix.

This curvature_preload matrix is computed once, before lens modeling, and reused throughout lens modeling. For high number of visibilities and resolution real space mask, this calculation can take minutes or hours on a CPU.

The previous PR did not convert this calculation to JAX or run on GPU.

This pull request makes the W-Tilde curvature preload computation support JAX and GPU, with profiling suggeting at least x100 speed up for high resolution datasets, with the calculation taking under a minute for the highest resolution / visibilities ALMA datasets tested.

It also includes utilities for safely saving and loading precomputed data, which check metadata to ensure the loaded data matches the data being analysed.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request adds JAX/GPU support for W-Tilde curvature preload computation in interferometry, significantly improving performance (100x+ speedup) for high-resolution ALMA datasets. It also includes utilities for safely loading precomputed data with metadata validation.

Changes:

  • Adds import for load_curvature_preload_if_compatible utility function to the main package
  • Removes outdated documentation about NaN handling for isothermal mass profile deflections
  • Updates save_json call in interferometer analysis (with a critical bug)

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
autogalaxy/init.py Exposes the new curvature preload utility function for external use
autogalaxy/profiles/mass/total/isothermal.py Removes outdated documentation about coordinate shifting to avoid NaN values
autogalaxy/interferometer/model/analysis.py Modifies save_json call for transformer_class (introduces breaking change)

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


paths.save_json(
"transformer_class", to_dict(self.dataset.transformer.__class__), "dataset"
"transformer_class", to_dict(self.dataset.transformer.__class__),
Copy link

Copilot AI Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The removal of the third parameter "dataset" from this save_json call may break data loading in the aggregator. The aggregator loads this value using fit.value(name="dataset.transformer_class"), which expects the JSON to be saved in a "dataset" subdirectory. With this change, the file would be saved at the root level instead, causing the aggregator to fail when trying to load it. The third parameter should likely be restored to maintain compatibility with the aggregator.

Suggested change
"transformer_class", to_dict(self.dataset.transformer.__class__),
"transformer_class", to_dict(self.dataset.transformer.__class__), "dataset"

Copilot uses AI. Check for mistakes.
@Jammy2211 Jammy2211 merged commit 5b95eae into main Jan 17, 2026
8 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax_w_tilde_preload branch February 13, 2026 13:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants