Releases: Jammy2211/PyAutoLens
Pixelization API + Cosmology
Simplified JAX Pixelization API
The pixelization API no longer needs the preload arrays to perform modeling in JAX, simplifying the use of pixelized source reconstructions.
For example, the API to set up a rectangular mesh now only needs its shape:
"""
__Mesh Shape__
The `mesh_shape` parameter defines the number of pixels used by the rectangular mesh to reconstruct the source,
set below to 28 x 28.
The `mesh_shape` must be fixed before modeling and cannot be a free parameter of the model, because JAX uses the
mesh shape to define the statically shaped arrays used to reconstruct the source. For a rectangular
mesh, the same number of pixels must be used in the y and x directions.
__Edge Zeroing__
By default, all pixels at the edge of the mesh in the source-plane are forced to zero brightness by
the linear algebra solver. This prevents unphysical solutions where pixels at the edge of the mesh reconstruct
high surface brightness, often because they fit residuals from the lens light subtraction.
For a rectangular mesh, the edge pixels follow directly from the mesh shape, so the source code computes
them internally.
"""
mesh_pixels_yx = 28
mesh_shape = (mesh_pixels_yx, mesh_pixels_yx)
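The edge-zeroing idea can be sketched in plain NumPy. This is a minimal illustration, not the PyAutoArray implementation: it assumes row-major pixel indexing for the rectangular mesh and pins edge pixels to zero by removing them from the linear system, which is one simple way to do it. `rectangular_edge_pixels` and `solve_with_zeroed_edges` are hypothetical names.

```python
import numpy as np


def rectangular_edge_pixels(mesh_shape):
    """Indices of pixels on the border of a rectangular mesh (row-major indexing assumed)."""
    ny, nx = mesh_shape
    index = np.arange(ny * nx).reshape(ny, nx)
    edge = np.zeros((ny, nx), dtype=bool)
    edge[0, :] = edge[-1, :] = True
    edge[:, 0] = edge[:, -1] = True
    return index[edge]


def solve_with_zeroed_edges(curvature_reg_matrix, data_vector, edge_pixels):
    """Solve F s = d with the edge pixels of s fixed to zero brightness (hypothetical helper).

    Edge pixels are dropped from the system, the reduced system is solved, and the
    edge entries of the solution are left at zero.
    """
    n = data_vector.shape[0]
    keep = np.setdiff1d(np.arange(n), edge_pixels)
    solution = np.zeros(n)
    solution[keep] = np.linalg.solve(
        curvature_reg_matrix[np.ix_(keep, keep)], data_vector[keep]
    )
    return solution


mesh_shape = (28, 28)
edge_pixels = rectangular_edge_pixels(mesh_shape)
print(len(edge_pixels))  # 4 * 28 - 4 border pixels
```

Because the edge pixels depend only on `mesh_shape`, they can be computed once with a static size, which fits the fixed-shape requirement described above.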
Pull Request: https://github.com/Jammy2211/PyAutoLens/pulls?q=is%3Apr+is%3Aclosed
Example: https://github.com/Jammy2211/autolens_workspace/blob/release/notebooks/imaging/features/pixelization/modeling.ipynb
Cosmology JAX
All Astropy cosmology imports have been removed, with all cosmology calculations now using built-in PyAutoLens functions which fully support JAX.
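As an illustration of cosmology written without Astropy, the sketch below computes flat ΛCDM distances with a trapezoid-rule integral written against an array namespace `xp`, so the same code could run with `numpy` or `jax.numpy`. It is not PyAutoLens's implementation; the function names, the defaults (H0 = 70 km/s/Mpc, Om0 = 0.3) and the neglect of radiation are assumptions.

```python
import numpy as np

C_KM_S = 299792.458  # speed of light in km/s


def comoving_distance_flat(z, H0=70.0, Om0=0.3, n_steps=10_000, xp=np):
    """Line-of-sight comoving distance in Mpc for flat LambdaCDM (radiation ignored)."""
    zs = xp.linspace(0.0, z, n_steps)
    inv_E = 1.0 / xp.sqrt(Om0 * (1.0 + zs) ** 3 + (1.0 - Om0))
    # Trapezoid rule written out explicitly so it only needs basic array operations.
    integral = xp.sum(0.5 * (inv_E[1:] + inv_E[:-1]) * (zs[1:] - zs[:-1]))
    return (C_KM_S / H0) * integral


def angular_diameter_distance_flat(z, **kwargs):
    """Angular diameter distance from the observer to redshift z (Mpc)."""
    return comoving_distance_flat(z, **kwargs) / (1.0 + z)


def angular_diameter_distance_between_flat(z1, z2, **kwargs):
    """Angular diameter distance between z1 < z2 (e.g. lens to source); flat universe only."""
    d1 = comoving_distance_flat(z1, **kwargs)
    d2 = comoving_distance_flat(z2, **kwargs)
    return (d2 - d1) / (1.0 + z2)


# Lens at z = 0.5, source at z = 1.0.
D_l = angular_diameter_distance_flat(0.5)
D_s = angular_diameter_distance_flat(1.0)
D_ls = angular_diameter_distance_between_flat(0.5, 1.0)
```

Avoiding library-specific integrators keeps every step traceable by JAX, at the cost of choosing `n_steps` by hand.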
JAX cored and elliptical NFW
Both now fully supported:
Jammy2211/PyAutoGalaxy#279
Jammy2211/PyAutoGalaxy#280
Adaptive Matern Kernel
A Matern kernel regularization scheme which adapts to the source's brightness has been implemented: Jammy2211/PyAutoArray#214
Tests on the rectangular mesh show this is the best-performing regularization scheme for that mesh.
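A minimal sketch of the non-adaptive idea behind Matern regularization: a Matern (nu = 3/2) covariance C is built between source-pixel centres, and its inverse, scaled by a regularization coefficient, acts as the regularization matrix H in the penalty sᵀHs. The brightness-adaptive weighting from PyAutoArray#214 is not reproduced here, and the function names, the choice nu = 3/2 and the jitter term are assumptions.

```python
import numpy as np


def matern_32_covariance(points, length_scale, sigma=1.0):
    """Matern (nu = 3/2) covariance between every pair of (N, 2) pixel centres."""
    r = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    s = np.sqrt(3.0) * r / length_scale
    return sigma**2 * (1.0 + s) * np.exp(-s)


def matern_regularization_matrix(points, length_scale, coefficient, jitter=1e-8):
    """Regularization matrix H = coefficient * C^-1 (non-adaptive sketch)."""
    cov = matern_32_covariance(points, length_scale)
    cov = cov + jitter * np.eye(len(points))  # small jitter keeps the inversion stable
    return coefficient * np.linalg.inv(cov)


# Centres of a 5 x 5 rectangular mesh with unit spacing.
yy, xx = np.meshgrid(np.arange(5.0), np.arange(5.0), indexing="ij")
centres = np.column_stack([yy.ravel(), xx.ravel()])
H = matern_regularization_matrix(centres, length_scale=1.5, coefficient=2.0)
```

Unlike a gradient or curvature penalty between neighbouring pixels only, every pair of pixels is correlated here, with the correlation decaying over `length_scale`.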
Mesh Refactor
The above API changes are part of a larger refactoring of the pixelized source module in autoarray, with the mesh and mapper modules now much cleaner in terms of code:
Pull Request: #388
Convolution Refactor
The PSF convolution has been moved to an operator module and cleaned up, including better internal padding of the data for light from outside the mask that blurs into it.
Pull Requests:
github.com/Jammy2211/PyAutoArray/pulls?q=is%3Apr+is%3Aclosed
Jammy2211/PyAutoArray#222
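Why padding matters can be shown with a small NumPy sketch: light is evaluated on the pixels inside the mask plus a "blurring" region of pixels just outside it (those within the PSF footprint), the image is convolved, and only pixels inside the mask are kept. This is not the PyAutoArray operator; the function names are hypothetical, and here `True` marks pixels inside the mask.

```python
import numpy as np


def convolve_same(image, kernel):
    """Direct 2D convolution with zero padding; output matches the image shape (odd kernels)."""
    ky, kx = kernel.shape
    py, px = ky // 2, kx // 2
    padded = np.pad(image, ((py, py), (px, px)))
    flipped = kernel[::-1, ::-1]
    out = np.zeros(image.shape, dtype=float)
    for i in range(ky):
        for j in range(kx):
            out += flipped[i, j] * padded[i:i + image.shape[0], j:j + image.shape[1]]
    return out


def blurring_region(inside, kernel_shape):
    """Pixels outside the mask whose light the kernel can blur into it."""
    ky, kx = kernel_shape
    py, px = ky // 2, kx // 2
    padded = np.pad(inside, ((py, py), (px, px)))
    dilated = np.zeros(inside.shape, dtype=bool)
    for i in range(ky):
        for j in range(kx):
            dilated |= padded[i:i + inside.shape[0], j:j + inside.shape[1]]
    return dilated & ~inside


def blurred_model_in_mask(model_image, inside, kernel):
    """Convolve light from the mask plus its blurring region, then keep only masked pixels."""
    evaluate = inside | blurring_region(inside, kernel.shape)
    blurred = convolve_same(np.where(evaluate, model_image, 0.0), kernel)
    return np.where(inside, blurred, 0.0)


inside = np.zeros((7, 7), dtype=bool)
inside[2:5, 2:5] = True                 # 3 x 3 region inside the mask
kernel = np.full((3, 3), 1.0 / 9.0)     # toy uniform PSF
model = np.zeros((7, 7))
model[1, 1] = 1.0                       # light just outside the mask
blurred = blurred_model_in_mask(model, inside, kernel)
```

Dropping the blurring region would leave `blurred[2, 2]` at zero, even though light from the neighbouring pixel should blur into it.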
NaN Handling
Improvements to NaN handling remove certain failures during lens modeling.
JAX Interferometry More Speed Up + Shapelets
This release speeds up interferometer lens modeling even further: uv-plane lens modeling on an HPC GPU takes mere hours for datasets with 100,000,000+ visibilities at extremely high resolution, with run times under 30 minutes for more modest datasets.
This release also includes mature JAX support for shapelets thanks to @Chocologism.
Interferometer Speed Up
https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/interferometer
The fast interferometry JAX GPU implementation (Jammy2211/PyAutoArray#201) uses a preload of the NUFFT in order to compute the curvature_matrix.
This curvature_preload matrix is computed once, before lens modeling, and reused throughout. For large numbers of visibilities and high-resolution real-space masks, this calculation can take minutes or hours on a CPU.
The previous PR did not convert this calculation to JAX, so it could not run on GPU.
This pull request makes the W-Tilde curvature preload computation support JAX and GPU, with profiling suggesting at least a x100 speed up for high resolution datasets; the calculation takes under a minute for the highest resolution / visibility ALMA datasets tested.
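Why a curvature preload makes run time independent of the number of visibilities can be sketched with a direct Fourier transform. For real-space pixel centres x_i, uv-coordinates u_k and noise σ_k, the matrix W̃_ij = Σ_k cos(2π u_k · (x_i − x_j)) / σ_k² equals Re(Tᴴ N⁻¹ T), and on a regular grid it depends only on the pixel offset, so it can be tabulated once per offset. The curvature matrix is then Mᵀ W̃ M for a mapping matrix M, with no further reference to the visibilities. This NumPy sketch uses hypothetical function names and simplified sign and unit conventions; it is not the PyAutoArray JAX code and makes no attempt at efficiency.

```python
import numpy as np


def w_tilde_direct(uv, grid, noise):
    """W~_ij = sum_k cos(2 pi u_k . (x_i - x_j)) / sigma_k^2 for (N_pix, 2) grid, (N_vis, 2) uv."""
    diff = grid[:, None, :] - grid[None, :, :]           # (N_pix, N_pix, 2)
    phase = 2.0 * np.pi * (diff @ uv.T)                  # (N_pix, N_pix, N_vis)
    return np.cos(phase) @ (1.0 / noise**2)              # (N_pix, N_pix)


def curvature_preload(uv, noise, shape, pixel_scale):
    """Tabulate W~ once for every pixel offset of a regular (ny, nx) grid: (2ny-1, 2nx-1)."""
    ny, nx = shape
    dy = np.arange(-(ny - 1), ny) * pixel_scale
    dx = np.arange(-(nx - 1), nx) * pixel_scale
    offsets = np.stack(np.meshgrid(dy, dx, indexing="ij"), axis=-1)   # (2ny-1, 2nx-1, 2)
    return np.cos(2.0 * np.pi * (offsets @ uv.T)) @ (1.0 / noise**2)


def w_tilde_from_preload(preload, pixel_indices):
    """Rebuild W~ for (N_pix, 2) integer (iy, ix) pixel indices by looking up offsets."""
    ny = (preload.shape[0] + 1) // 2
    nx = (preload.shape[1] + 1) // 2
    d = pixel_indices[:, None, :] - pixel_indices[None, :, :]
    return preload[d[..., 0] + ny - 1, d[..., 1] + nx - 1]


rng = np.random.default_rng(1)
shape, pixel_scale = (4, 4), 0.1
iy, ix = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
pixel_indices = np.column_stack([iy.ravel(), ix.ravel()])
grid = pixel_indices * pixel_scale
uv = rng.normal(scale=2.0, size=(500, 2))       # toy uv-coverage
noise = rng.uniform(0.5, 1.5, size=500)

preload = curvature_preload(uv, noise, shape, pixel_scale)   # computed once, before modeling
W = w_tilde_from_preload(preload, pixel_indices)
mapping = rng.uniform(size=(16, 5))                          # toy mapping matrix M
curvature_matrix = mapping.T @ W @ mapping
```

Only the one-off `curvature_preload` call touches the visibilities; every later likelihood evaluation works with arrays whose size is set by the real-space grid.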
It also includes utilities for safely saving and loading precomputed data, which check metadata to ensure the loaded data matches the data being analysed.
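The idea of metadata-checked preloads can be illustrated with a small NumPy sketch. The saved file stores a hash of the inputs the preload was computed from, and loading refuses to return the preload if the current inputs hash differently. The helper names, the `.npz` format and the choice of inputs to hash (`uv_wavelengths`, `real_space_mask`) are assumptions, not the PyAutoArray utilities.

```python
import hashlib
import json

import numpy as np


def fingerprint(*arrays):
    """SHA-256 hash of the shapes, dtypes and bytes of the given arrays."""
    h = hashlib.sha256()
    for a in arrays:
        a = np.ascontiguousarray(a)
        h.update(str(a.shape).encode())
        h.update(str(a.dtype).encode())
        h.update(a.tobytes())
    return h.hexdigest()


def save_preload(path, preload, uv_wavelengths, real_space_mask):
    """Save a preload together with a fingerprint of the data it was computed from."""
    meta = {"fingerprint": fingerprint(uv_wavelengths, real_space_mask)}
    np.savez(path, preload=preload, meta=json.dumps(meta))


def load_preload(path, uv_wavelengths, real_space_mask):
    """Load a preload, raising ValueError if it was computed for different data."""
    with np.load(path) as f:
        meta = json.loads(f["meta"].item())
        if meta["fingerprint"] != fingerprint(uv_wavelengths, real_space_mask):
            raise ValueError("Saved preload does not match the current dataset; recompute it.")
        return f["preload"]
```

Hashing the raw bytes, rather than comparing only shapes, also catches a dataset with the same dimensions but different uv-coverage.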
Shapelets
Full JAX support for elliptical polar and Cartesian shapelets, with the elliptical power shapelet the recommended default!
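For reference, the circular Cartesian shapelet basis (Refregier 2003) is a product of normalised Hermite functions. The sketch below builds it with NumPy and checks orthonormality numerically. It omits ellipticity and the polar form, and the function names and normalisation convention are assumptions rather than PyAutoGalaxy's parameterisation.

```python
import math

import numpy as np
from numpy.polynomial.hermite import hermval


def shapelet_1d(n, x, beta):
    """Dimensional 1D shapelet B_n(x; beta), normalised so its square integrates to 1."""
    coeffs = np.zeros(n + 1)
    coeffs[n] = 1.0                                   # selects the physicists' Hermite H_n
    norm = 1.0 / math.sqrt(2.0**n * math.sqrt(math.pi) * math.factorial(n) * beta)
    t = x / beta
    return norm * hermval(t, coeffs) * np.exp(-0.5 * t**2)


def shapelet_cartesian(n_y, n_x, y, x, beta):
    """Circular Cartesian shapelet B_{n_y, n_x}(y, x; beta) as a product of 1D shapelets."""
    return shapelet_1d(n_y, y, beta) * shapelet_1d(n_x, x, beta)


beta = 0.3
axis = np.linspace(-10 * beta, 10 * beta, 801)
dA = (axis[1] - axis[0]) ** 2
yy, xx = np.meshgrid(axis, axis, indexing="ij")
b12 = shapelet_cartesian(1, 2, yy, xx, beta)
b21 = shapelet_cartesian(2, 1, yy, xx, beta)
```

Orthonormality is what makes shapelet coefficients behave well as linear parameters: each basis function contributes independently to the fit.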
JAX Interferometry
Efficient and scalable implementation of pixelized source reconstructions for interferometer analysis using JAX.
This PR introduces a new approach to source reconstructions for interferometer data that fully exploits the symmetries and sparsity of the non-uniform fast Fourier transform (NUFFT).
A high level summary of the implementation is:
- Pixelized source reconstructions are performed such that the run time and VRAM usage are independent of the number of visibilities.
- Lens modeling run times are fast, with a 1+ million visibility ALMA dataset modeled in around 1 hour on a GPU!
- Other improvements to interferometer analysis, plus a significant amount of supporting documentation and examples on the autolens_workspace, are now provided.
Whilst a quantitative comparison has not yet been performed, my intuition is that this code runs significantly faster than the previous PyAutoLens interferometer modeling and the Powell et al. implementation.
Check out the interferometer package of the autolens_workspace for a complete run-through of how to use JAX GPU interferometer analysis!
https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/interferometer
Delaunay JAX
The adaptive Delaunay mesh using a Hilbert image-mesh now supports fully JAX'd likelihood functions running on GPU, which were disabled in previous releases.
The Delaunay mesh itself is computed on the CPU rather than the GPU, via a JAX pure_callback; this does not significantly impact performance.
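A JAX pure_callback requires the host function to return arrays of a fixed, known shape. One way to satisfy that for a Delaunay mesh, sketched below as an assumption rather than the PyAutoArray design, is for the CPU function to return, for each traced image-plane point, the three vertices of its enclosing triangle and the matching barycentric interpolation weights. The output is then (N, 3) however many triangles the mesh has. The function name is hypothetical; it uses `scipy.spatial.Delaunay`, and the JAX wrapping is only indicated in the docstring.

```python
import numpy as np
from scipy.spatial import Delaunay


def delaunay_interpolation(mesh_points, query_points):
    """Vertices (N, 3) and barycentric weights (N, 3) of the triangle containing each query point.

    Points outside the triangulation get zero weights. Output shapes depend only on the
    number of query points, so a function like this can sit behind jax.pure_callback with
    jax.ShapeDtypeStruct specs describing the two (N, 3) outputs.
    """
    tri = Delaunay(mesh_points)
    simplex = tri.find_simplex(query_points)          # -1 for points outside the hull
    inside = simplex >= 0
    s = np.where(inside, simplex, 0)                  # placeholder triangle for outside points
    transform = tri.transform[s]                      # (N, 3, 2) affine maps to barycentric coords
    b = np.einsum("nij,nj->ni", transform[:, :2, :], query_points - transform[:, 2, :])
    weights = np.column_stack([b, 1.0 - b.sum(axis=1)])
    weights[~inside] = 0.0
    return tri.simplices[s], weights


mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
queries = np.array([[0.2, 0.3], [0.7, 0.6], [2.0, 2.0]])
vertices, weights = delaunay_interpolation(mesh, queries)
```

Barycentric weights reproduce any linear function exactly, which is a quick sanity check for this kind of interpolation.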
JAX improvements + Fast CPU Pixelizations support + Delaunay
This release continues to build stability for JAX + GPU support:
https://github.com/Jammy2211/PyAutoLens/releases/tag/2025.11.18.1
This includes many small bug fixes; for example, more light and mass profiles now support JAX after small fixes to the source code.
Fast CPU Pixelizations
Before this release, pixelized source reconstructions could only be computed using JAX, either via GPU or CPU.
There were two important factors in run time:
1. GPU VRAM Limitations
JAX only provides significant acceleration on GPUs with large VRAM (≥16 GB).
To avoid excessive VRAM usage, examples often restrict pixelization meshes to small sizes (e.g. 20 × 20).
On consumer GPUs with limited memory, JAX may be slower than CPU execution.
2. Sparse Matrix Performance
Pixelized source reconstructions require operations on very large, highly sparse matrices.
- JAX currently lacks mature sparse-matrix support and must compute using dense matrices, which scale poorly.
This release restores support for PyAutoLens’s previous CPU implementation (via numba), which fully exploits sparsity, providing large speed gains at high image resolution (e.g. pixel_scales <= 0.03 arcsec).
CPU execution can outperform JAX, even on powerful GPUs, for high-resolution datasets or when many CPU cores are used.
Development is actively working on getting better performance from JAX by exploiting sparsity on GPU, but this is proving to be a very challenging problem.
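The sparsity argument can be made concrete. In a pixelized reconstruction each image pixel maps to only a few source pixels, so the mapping matrix M is mostly zeros, and the curvature matrix F = Mᵀ N⁻¹ M can be built from the non-zero entries alone. The SciPy sketch below builds the same F sparsely and densely. It is not the numba implementation: PSF blurring of M is omitted, and the toy sizes and the helper name are hypothetical.

```python
import numpy as np
import scipy.sparse as sp


def sparse_mapping_matrix(pix_indexes, pix_weights, n_source_pixels):
    """CSR mapping matrix from (N_image, K) source-pixel indexes and interpolation weights."""
    n_image, k = pix_indexes.shape
    rows = np.repeat(np.arange(n_image), k)
    return sp.csr_matrix(
        (pix_weights.ravel(), (rows, pix_indexes.ravel())),
        shape=(n_image, n_source_pixels),
    )


rng = np.random.default_rng(2)
n_image, n_source, k = 3000, 600, 3
pix_indexes = rng.integers(0, n_source, size=(n_image, k))
pix_weights = rng.dirichlet(np.ones(k), size=n_image)        # weights per image pixel sum to 1
noise_map = rng.uniform(0.5, 2.0, size=n_image)

M = sparse_mapping_matrix(pix_indexes, pix_weights, n_source)
# Sparse: only the ~n_image * k non-zero entries take part in the product.
F_sparse = (M.T @ sp.diags(1.0 / noise_map**2) @ M).toarray()
# Dense: every one of the n_image * n_source entries is stored and multiplied.
M_dense = M.toarray()
F_dense = M_dense.T @ (M_dense / noise_map[:, None] ** 2)
print(M.nnz / (n_image * n_source))   # fraction of non-zero entries
```

The dense cost grows with the product of image and source pixel counts, while the sparse cost grows roughly with the number of image pixels, which is why the gap widens at small pixel scales.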
Delaunay
Support for the Delaunay mesh, previously the main pixelized source reconstruction, has been restored in this release, although it currently only works with the numba implementation above and therefore only supports CPU.
Development is actively working on JAX support for Delaunay source reconstructions, which is expected to be available in the short term.
PyAutoLens JAX GPU Stability
PyAutoLens JAX Stability
The source code no longer imports or uses JAX without user instruction, meaning all calculations use regular numpy by default.
JAX is imported and used by Analysis objects when lens modeling begins, ensuring that fast lens modeling using GPUs is always performed by default.
The design of PyAutoLens will build on this, whereby for more general lensing calculations users will perform JAX jitting and computation themselves. The docs and guides illustrating this are not written yet, but normal numpy run times are acceptable for most use cases.
Workspace Restructure
The workspace has been restructured such that the core packages are now the dataset types (imaging, interferometer, etc.):
https://github.com/Jammy2211/autolens_workspace
GPU Modeling Examples
The following Jupyter Notebooks, which run via Google Colab, illustrate lens modeling in under 10 minutes for different science cases:
- imaging/start_here.ipynb: Galaxy-scale strong lenses observed with CCD imaging (e.g. Hubble, James Webb).
- interferometer/start_here.ipynb: Galaxy-scale strong lenses observed with interferometer data (e.g. ALMA).
- point_source/start_here.ipynb: Galaxy-scale strong lenses with a lensed point source (e.g. lensed quasars).
- group/start_here.ipynb: Group-scale strong lenses where there are 2-10 lens galaxies.
PyAutoLens JAX Stability Pull Requests
These are described fully in the following two PRs:
A large refactor which passes the numpy or jax.numpy module through the code as `xp`.
This means that no JAX arrays are created inside the source code by default, with all calculations defaulting to NumPy, giving the following benefits:
- Unit tests and general code use run faster, as JAX overheads are removed.
- Numba support for efficient CPU use can easily be retained, as JAX and NumPy arrays are not mixed.
- There is less ambiguity in sections of code which don't play nicely with JAX arrays (e.g. visualization).
- It will allow an easier, more explicit user interface, where users JAX-jit functions themselves and pass the JAX namespace to get fast run times.
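The `xp` pattern can be illustrated with a toy function that takes the array module as an argument, defaults to NumPy, and would run unchanged under `jax.numpy`. The function below (singular isothermal sphere deflections) is a hypothetical example, not PyAutoLens source code; the JAX usage is shown only in comments.

```python
import numpy as np


def sis_deflections(grid, einstein_radius, xp=np):
    """Deflection angles of a singular isothermal sphere at (N, 2) (y, x) coordinates.

    Every array operation goes through `xp`, so passing `xp=jax.numpy` gives a
    JAX-traceable version of the same function.
    """
    r = xp.sqrt(grid[:, 0] ** 2 + grid[:, 1] ** 2)
    factor = einstein_radius / r          # deflection magnitude equals the Einstein radius
    return xp.stack([factor * grid[:, 0], factor * grid[:, 1]], axis=-1)


grid = np.array([[0.0, 2.0], [3.0, 4.0]])
deflections = sis_deflections(grid, einstein_radius=1.5)   # plain NumPy, no JAX overhead

# With JAX installed, the user opts in explicitly, e.g.:
#   import jax
#   import jax.numpy as jnp
#   fast = jax.jit(lambda g: sis_deflections(g, 1.5, xp=jnp))
#   fast(jnp.asarray(grid))
```

Because the module is an ordinary argument, the default path never touches JAX, and opting in is a one-line change at the call site.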
A recent PR on the child projects made JAX optional for likelihood functions, whereby users pass the JAX namespace as the variable `xp` through the source code.
This PR makes JAX optional at the highest level (e.g. `PyAutoConf` and `PyAutoFit`), including:
- For a non-linear search to use JAX, the `use_jax` input must be passed as `True` to the `Analysis` object.
- The non-linear search will internally work out whether it supports JAX natively. Ultimately this will mean that, for example, if gradients are used it uses `jax.grad`, if not it uses `jax.jit`, and if batching is supported it uses `jax.vmap`.
- Currently only `Nautilus` uses the `Analysis.use_jax` attribute to set up a `jax.vmap`.
There are a few hacky, unclean parts of the autofit model composition, where it determines whether to use JAX based on the input type. A more thorough consideration of how JAX should work in autofit will be performed in the future.
PyAutoLens-JAX GPU
UPDATE: Latest JAX version is now 2025.11.5.1
This release marks the completion of two years work implementing JAX (https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html) in PyAutoLens.
With JAX, any lens modeling analysis can be run on GPU, with speed-ups of ~x50 or more.
Core Release
The core PyAutoLens API does not change significantly; however, existing users should redownload the new autolens workspace, which has new configs and examples:
https://github.com/Jammy2211/autolens_workspace
New users should check out the start_here.ipynb notebook, which can be opened in Google Colab by clicking the hyperlink.
GPU Modeling Examples
The following Jupyter Notebooks, which run via Google Colab, illustrate lens modeling in under 10 minutes for different science cases:
- start_here_imaging.ipynb: Galaxy-scale strong lenses observed with CCD imaging (e.g. Hubble, James Webb).
- start_here_interferometer.ipynb: Galaxy-scale strong lenses observed with interferometer data (e.g. ALMA).
- start_here_point_source.ipynb: Galaxy-scale strong lenses with a lensed point source (e.g. lensed quasars).
- start_here_group.ipynb: Group-scale strong lenses where there are 2-10 lens galaxies.
- start_here_multi_wavelength.ipynb: Model multiple images (different wavelength imaging, imaging + interferometer) simultaneously.
Performance Of Other Features
- Pixelized sources run ~x5 - x20 faster on modern HPC GPU clusters, with lens modeling times typically ~10 - 20 minutes. Pixelized source performance depends on the available GPU VRAM. A release in November 2025 will bring pixelized source lens modeling on all GPU hardware close to < 10 minutes.
- Interferometer with many visibilities: above ~100,000 visibilities, interferometer performance slows down significantly. In December 2025 a new release will make all interferometer modeling efficient irrespective of the number of visibilities.
- CPU performance: for pixelized sources, CPU performance is worse than the previous PyAutoLens, as JAX is not optimized for CPUs. A future release will restore performance to be on par with previous versions, but users seeking to perform pixelized source modeling without a GPU may wish to use the previous PyAutoLens.
Strong Lens Galaxy Clusters
This release can perform strong lens cluster calculations and lens modeling on GPU. For those familiar with cluster lensing, this includes performing an image-plane chi-squared multiple image calculation for clusters with hundreds of cluster members, with full support for multi-plane ray tracing of the entire cluster!
Initial profiling shows it runs 50 or more times faster than other strong lens cluster codes run on CPU. Documentation and examples for cluster modeling are actively being developed but not yet mature. You can find the most up to date examples at the following links:
https://github.com/Jammy2211/autolens_workspace/blob/release/start_here_cluster.ipynb
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/simulators/cluster
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/modeling/cluster
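An image-plane chi-squared for multiple images compares each observed image position with the nearest model image position found by ray-tracing the source back to the image plane. Below is a minimal NumPy sketch. It assumes the model image positions have already been solved for, and it uses nearest-neighbour pairing with a single positional uncertainty; both are simplifications, and the function name is hypothetical.

```python
import numpy as np


def image_plane_chi_squared(observed, model, sigma):
    """Sum over observed images of (distance to nearest model image / sigma)^2."""
    separations = np.linalg.norm(observed[:, None, :] - model[None, :, :], axis=-1)
    nearest = separations.min(axis=1)      # each observed image paired with its closest model image
    return np.sum((nearest / sigma) ** 2)


observed = np.array([[1.0, 0.0], [0.0, 1.0]])
model = np.array([[1.1, 0.0], [0.0, 0.8], [5.0, 5.0]])   # third model image has no observed match
chi_squared = image_plane_chi_squared(observed, model, sigma=0.1)
```

For a cluster, a term like this would be summed over every multiply imaged source, each traced through the full multi-plane lens model.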
May 2025
- Results workflow API, which generates .csv, .png and .fits files of large libraries of results for quick and efficient inspection:
https://github.com/Jammy2211/autolens_workspace/tree/main/notebooks/results/workflow
- Visualization now outputs .fits files corresponding to each subplot, which more concisely contain all information of a fit and are used by the above workflow API.
- Visualization simplified, removing customization of individual image outputs.
- Removed the Analysis summing API, replacing all dataset combinations with the `AnalysisFactor` and `FactorGraphModel` API used for graphical modeling.
- Pixelized source reconstructions are output as a .csv file, which can be loaded and interpolated for better source science analysis.
- Double source plane lens modeling now outputs an individual subplot_fit for each plane.
- Latent variable API bug fixes; latent variables are now used in some test example scripts.
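A reconstruction saved as a table of source-plane coordinates and values can be interpolated onto any grid with SciPy. The column names `y`, `x` and `reconstruction` below are assumptions, as is the in-memory CSV standing in for the file PyAutoLens writes; check the actual file's header before adapting this.

```python
import io

import numpy as np
from scipy.interpolate import griddata

# Stand-in for the .csv file (hypothetical columns: y, x, reconstruction).
csv_text = """y,x,reconstruction
0.0,0.0,1.0
0.0,1.0,3.0
1.0,0.0,2.0
1.0,1.0,4.0
0.5,0.5,2.5
"""

table = np.genfromtxt(io.StringIO(csv_text), delimiter=",", names=True)
points = np.column_stack([table["y"], table["x"]])
values = table["reconstruction"]

# Interpolate onto a regular source-plane grid; points outside the convex hull become NaN.
yy, xx = np.meshgrid(np.linspace(0.0, 1.0, 11), np.linspace(0.0, 1.0, 11), indexing="ij")
interpolated = griddata(points, values, (yy, xx), method="linear")
```

Linear interpolation over a triangulation matches how an irregular mesh such as Delaunay represents the source, and it avoids inventing structure between pixels.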
January 2025
The main updates are visualization of Delaunay meshes using Delaunay triangles and a significant refactoring of over sampling, primarily to make the code much less complex for the ongoing JAX implementation.
There have also been more improvements to point source modeling, including JAX functionality, which will be documented fully in the near future.
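Over sampling means evaluating a light profile on a grid of sub-pixels within each image pixel and averaging. When the profile varies rapidly, this approximates the profile's integral over the pixel far better than a single evaluation at the pixel centre. Below is a minimal uniform over-sampling sketch; the function names, coordinate conventions and test profile are assumptions, and the PyAutoLens over-sampling API is not shown.

```python
import numpy as np


def pixel_centres(n, pixel_scale):
    """1D pixel-centre coordinates of n pixels, centred on zero."""
    return (np.arange(n) - (n - 1) / 2.0) * pixel_scale


def over_sampled_image(profile, shape, pixel_scale, sub_size):
    """Average profile(y, x) over a sub_size x sub_size grid of sub-pixels in every pixel."""
    ny, nx = shape
    offsets = ((np.arange(sub_size) + 0.5) / sub_size - 0.5) * pixel_scale
    y = pixel_centres(ny, pixel_scale)[:, None] + offsets[None, :]    # (ny, sub_size)
    x = pixel_centres(nx, pixel_scale)[:, None] + offsets[None, :]    # (nx, sub_size)
    values = profile(y[:, :, None, None], x[None, None, :, :])        # (ny, sub, nx, sub)
    return values.mean(axis=(1, 3))


def gaussian(y, x, sigma=0.05):
    return np.exp(-0.5 * (y**2 + x**2) / sigma**2)


coarse = over_sampled_image(gaussian, (5, 5), pixel_scale=0.1, sub_size=1)
fine = over_sampled_image(gaussian, (5, 5), pixel_scale=0.1, sub_size=16)
```

With `sub_size=1` this reduces to evaluating each pixel centre; larger values lower the central pixel of the compact Gaussian below its peak, as a pixel-averaged flux should be.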
What's Changed
- Feature/disable noise by @Jammy2211 in #324
- feature/delaunay_visual by @Jammy2211 in #323
- feature/inversion_noise_map by @Jammy2211 in #325
- feature/positions_lh_mass_centre by @Jammy2211 in #326
- feature/triangle array typing by @rhayes777 in #328
- feature/array testing by @rhayes777 in #327
- Feature/over sampling refactor by @Jammy2211 in #332
- remove max containing size from solver by @rhayes777 in #329
- feature/andrew implementation by @rhayes777 in #331
Full Changelog: 2024.11.13.2...2025.1.18.7
November 2024 update
Small bug fixes and optimizations for Euclid lens modeling pipeline.