Conversation
There was a problem hiding this comment.
Pull request overview
This pull request adds JAX/GPU support for W-Tilde curvature preload computation in interferometry, significantly improving performance (100x+ speedup) for high-resolution ALMA datasets. It also includes utilities for safely loading precomputed data with metadata validation.
Changes:
- Adds import for
load_curvature_preload_if_compatibleutility function to the main package - Removes outdated documentation about NaN handling for isothermal mass profile deflections
- Updates
save_jsoncall in interferometer analysis (with a critical bug)
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| autogalaxy/init.py | Exposes the new curvature preload utility function for external use |
| autogalaxy/profiles/mass/total/isothermal.py | Removes outdated documentation about coordinate shifting to avoid NaN values |
| autogalaxy/interferometer/model/analysis.py | Modifies save_json call for transformer_class (introduces breaking change) |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| paths.save_json( | ||
| "transformer_class", to_dict(self.dataset.transformer.__class__), "dataset" | ||
| "transformer_class", to_dict(self.dataset.transformer.__class__), |
There was a problem hiding this comment.
The removal of the third parameter "dataset" from this save_json call may break data loading in the aggregator. The aggregator loads this value using fit.value(name="dataset.transformer_class"), which expects the JSON to be saved in a "dataset" subdirectory. With this change, the file would be saved at the root level instead, causing the aggregator to fail when trying to load it. The third parameter should likely be restored to maintain compatibility with the aggregator.
| "transformer_class", to_dict(self.dataset.transformer.__class__), | |
| "transformer_class", to_dict(self.dataset.transformer.__class__), "dataset" |
The fast interferometry JAX GPU implementation (Jammy2211/PyAutoArray#201) uses a preload of the NUFFT in order to compute the
curvature_matrix.This
curvature_preloadmatrix is computed once, before lens modeling, and reused throughout lens modeling. For high number of visibilities and resolution real space mask, this calculation can take minutes or hours on a CPU.The previous PR did not convert this calculation to JAX or run on GPU.
This pull request makes the W-Tilde curvature preload computation support JAX and GPU, with profiling suggeting at least x100 speed up for high resolution datasets, with the calculation taking under a minute for the highest resolution / visibilities ALMA datasets tested.
It also includes utilities for safely saving and loading precomputed data, which check metadata to ensure the loaded data matches the data being analysed.