Releases: Jammy2211/PyAutoLens

Pixelization API + Cosmology

26 Feb 19:48

Simplified JAX Pixelization API

The pixelization API no longer needs the preload arrays to perform modeling in JAX, simplifying the use of pixelized source reconstructions.

For example, the API to set up a rectangular mesh now only needs its shape:

"""
__Mesh Shape__

The `mesh_shape` parameter defines number of pixels used by the rectangular mesh to reconstruct the source,
set below to 28 x 28. 

The `mesh_shape` must be fixed before modeling and cannot be a free parameter of the model, because JAX uses the
mesh shape to define static shaped arrays which use the mesh to reconstruct the source. For a rectangular
mesh, the same number of pixels must be used in the y and x directions.

__Edge Zeroing__

By default, all pixels at the edge of the mesh in the source-plane are forced to solutions of zero brightness by 
the linear algebra solver. This prevents unphysical solutions where pixels at the edge of the mesh reconstruct 
bright surface brightnesses, often because they fit residuals from the lens light subtraction.

For a rectangular mesh, the source code computes edge pixels internally using the known
pixels at the edge of the mesh. 
"""
mesh_pixels_yx = 28
mesh_shape = (mesh_pixels_yx, mesh_pixels_yx)
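Because the edge pixels of a rectangular mesh depend only on its shape, they can be found without any data. The sketch below uses a hypothetical helper, `rectangular_edge_pixels`, which is not the PyAutoArray implementation, to list the row-major indices of the border pixels for the 28 x 28 mesh above.

```python
def rectangular_edge_pixels(mesh_shape):
    """Return the row-major 1D indices of pixels on the border of a rectangular mesh.

    Hypothetical helper for illustration; PyAutoArray computes edge pixels internally.
    """
    n_y, n_x = mesh_shape
    edge = []
    for iy in range(n_y):
        for ix in range(n_x):
            # A pixel is on the edge if it lies in the first/last row or first/last column.
            if iy in (0, n_y - 1) or ix in (0, n_x - 1):
                edge.append(iy * n_x + ix)
    return edge


mesh_pixels_yx = 28
mesh_shape = (mesh_pixels_yx, mesh_pixels_yx)
edge_pixels = rectangular_edge_pixels(mesh_shape)
print(len(edge_pixels))  # 4 * 28 - 4 border pixels
```

Since the number of edge pixels is fixed by `mesh_shape`, it is known before modeling, which is consistent with the requirement that the mesh shape is static under JAX.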

Pull Request: https://github.com/Jammy2211/PyAutoLens/pulls?q=is%3Apr+is%3Aclosed
Example: https://github.com/Jammy2211/autolens_workspace/blob/release/notebooks/imaging/features/pixelization/modeling.ipynb

Cosmology JAX

All Astropy Cosmology imports have been removed, with all cosmology calculations now using built-in PyAutoLens functions which fully support JAX:
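As an illustration of the kind of calculation that no longer relies on Astropy, the sketch below computes distances in a flat Lambda-CDM cosmology (radiation ignored) by integrating 1 / E(z) with Simpson's rule. It is a standalone, pure-Python sketch; the function names and default parameters (H0 = 70 km/s/Mpc, Omega_m = 0.3) are assumptions and not the PyAutoLens cosmology API.

```python
import math

SPEED_OF_LIGHT_KM_S = 299792.458


def hubble_function(z, omega_m):
    """E(z) = H(z) / H0 for a flat Lambda-CDM cosmology, ignoring radiation."""
    return math.sqrt(omega_m * (1.0 + z) ** 3 + (1.0 - omega_m))


def comoving_distance_mpc(z, h0=70.0, omega_m=0.3, n_steps=1000):
    """Line-of-sight comoving distance in Mpc, via composite Simpson integration of 1 / E(z)."""
    if n_steps % 2:
        n_steps += 1  # Simpson's rule needs an even number of intervals
    dz = z / n_steps
    total = 0.0
    for i in range(n_steps + 1):
        weight = 1 if i in (0, n_steps) else (4 if i % 2 else 2)
        total += weight / hubble_function(i * dz, omega_m)
    return (SPEED_OF_LIGHT_KM_S / h0) * total * dz / 3.0


def angular_diameter_distance_mpc(z, **kwargs):
    """In a flat cosmology, D_A(z) = D_C(z) / (1 + z)."""
    return comoving_distance_mpc(z, **kwargs) / (1.0 + z)


def angular_diameter_distance_between_mpc(z_lens, z_source, **kwargs):
    """Flat-cosmology angular diameter distance between two redshifts, as used in lensing."""
    d_lens = comoving_distance_mpc(z_lens, **kwargs)
    d_source = comoving_distance_mpc(z_source, **kwargs)
    return (d_source - d_lens) / (1.0 + z_source)


print(angular_diameter_distance_mpc(1.0), angular_diameter_distance_between_mpc(0.5, 1.0))
```

Written with plain arithmetic like this, the same logic can be expressed with `jax.numpy` operations, which is what makes a JAX-compatible in-house implementation possible.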

#382

JAX cored and elliptical NFW

Both are now fully supported:

Jammy2211/PyAutoGalaxy#279
Jammy2211/PyAutoGalaxy#280

Adaptive Matern Kernel

A Matern kernel regularization scheme, which adapts to the source's brightness, has been implemented: Jammy2211/PyAutoArray#214

Tests on the rectangular mesh show this is the best-performing regularization scheme for that mesh.
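For readers unfamiliar with the kernel, the sketch below builds a Matern covariance matrix between source-pixel centres using the closed-form smoothness nu = 3/2. The choice of nu, the `coefficient` and `scale` parameters and the helper names are assumptions; the release note does not describe how PyAutoArray's kernel adapts to the source's brightness, so that part is not shown.

```python
import math


def matern_32(distance, coefficient=1.0, scale=1.0):
    """Matern covariance with smoothness nu = 3/2 (closed form, no Bessel functions needed)."""
    r = math.sqrt(3.0) * distance / scale
    return coefficient * (1.0 + r) * math.exp(-r)


def matern_covariance_matrix(pixel_centres, coefficient=1.0, scale=1.0):
    """Covariance between every pair of source-pixel centres, given as (y, x) tuples."""
    n = len(pixel_centres)
    matrix = [[0.0] * n for _ in range(n)]
    for i, (yi, xi) in enumerate(pixel_centres):
        for j, (yj, xj) in enumerate(pixel_centres):
            matrix[i][j] = matern_32(math.hypot(yi - yj, xi - xj), coefficient, scale)
    return matrix


centres = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
for row in matern_covariance_matrix(centres, scale=0.5):
    print([round(value, 3) for value in row])
```

In a Gaussian-process style regularization, a covariance matrix like this would typically enter the likelihood through its inverse, penalizing source solutions that vary faster than the kernel's length scale allows.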

Mesh Refactor

The above API changes are part of a larger refactoring of the pixelized source module in autoarray, with the code of the mesh and mapper modules now much cleaner:

Pull Request: #388

Convolution Refactor

The PSF convolution has been moved to an operator module and cleaned up, including better internal padding of data when light from outside the mask blurs into it.
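The sketch below shows, with numpy, why padding matters when light from outside the mask blurs in. The helper `convolve_with_padding` and the boolean `fit_region` array (True marks pixels used in the fit) are hypothetical and follow neither PyAutoArray's operator module nor its mask convention; the point is only that blurring must see light beyond the fitted pixels.

```python
import numpy as np


def convolve_with_padding(image, kernel):
    """'Same'-size 2D convolution of an image with an odd-shaped PSF kernel.

    The image is zero-padded by the kernel half-width, so pixels near the border
    receive blurred light from every neighbour present in the input image.
    """
    ky, kx = kernel.shape
    pad_y, pad_x = ky // 2, kx // 2
    padded = np.pad(image, ((pad_y, pad_y), (pad_x, pad_x)))
    flipped = kernel[::-1, ::-1]  # true convolution flips the kernel
    blurred = np.zeros(image.shape, dtype=float)
    for iy in range(image.shape[0]):
        for ix in range(image.shape[1]):
            blurred[iy, ix] = np.sum(padded[iy:iy + ky, ix:ix + kx] * flipped)
    return blurred


fit_region = np.zeros((5, 5), dtype=bool)  # hypothetical convention: True = pixel is fitted
fit_region[1:4, 1:4] = True

image = np.zeros((5, 5))
image[2, 4] = 1.0  # a bright pixel just outside the fitted region

kernel = np.array([[0.0, 0.1, 0.0],
                   [0.1, 0.6, 0.1],
                   [0.0, 0.1, 0.0]])

# Blurring the full image lets light from outside the region blur into pixel (2, 3).
blurred_full = convolve_with_padding(image, kernel)
# Zeroing everything outside the region first loses that contribution.
blurred_region_only = convolve_with_padding(np.where(fit_region, image, 0.0), kernel)
print(blurred_full[2, 3], blurred_region_only[2, 3])
```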

Pull Requests:

github.com/Jammy2211/PyAutoArray/pulls?q=is%3Apr+is%3Aclosed
Jammy2211/PyAutoArray#222

NaN Handling

Improvements to NaN handling, which remove certain failures during lens modeling:

#383

JAX Interferometry More Speed Up + Shapelets

22 Jan 09:08

This release speeds up lens modeling of interferometer data even further, with uv-plane lens modeling on an HPC GPU taking mere hours for datasets with 100,000,000+ visibilities at extremely high resolution, and under 30 minutes for more modest datasets.

This release also includes mature JAX support for shapelets thanks to @Chocologism.

Interferometer Speed Up

Jammy2211/PyAutoArray#202

https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/interferometer

The fast interferometry JAX GPU implementation (Jammy2211/PyAutoArray#201) uses a preload of the NUFFT in order to compute the curvature_matrix.

This curvature_preload matrix is computed once, before lens modeling, and reused throughout lens modeling. For a high number of visibilities and a high-resolution real-space mask, this calculation can take minutes or hours on a CPU.

The previous PR did not convert this calculation to JAX, so it could not run on GPU.

This pull request makes the W-Tilde curvature preload computation support JAX and GPU, with profiling suggesting a speed up of at least x100 for high-resolution datasets; the calculation takes under a minute for the highest resolution / visibility ALMA datasets tested.
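The idea behind a visibility-independent curvature matrix can be sketched in a few lines of numpy. The example below is a minimal illustration under stated assumptions, not PyAutoArray's code: it uses a direct Fourier transform instead of a NUFFT, random made-up uv-coordinates, noise, pixel centres and mapping matrix, and assumes equal noise on the real and imaginary parts of each visibility. It checks that the curvature matrix computed directly in the uv-plane equals one computed from a real-space matrix `w_tilde`, which is precomputed once and whose size no longer depends on the number of visibilities.

```python
import numpy as np

rng = np.random.default_rng(0)

n_vis, n_image, n_source = 500, 16, 6
uv = rng.normal(scale=2.0, size=(n_vis, 2))        # made-up uv-coordinates (wavelengths)
noise_sigma = rng.uniform(0.5, 1.5, size=n_vis)     # made-up per-visibility noise
grid = rng.uniform(-1.0, 1.0, size=(n_image, 2))    # made-up real-space pixel centres
mapping = rng.uniform(size=(n_image, n_source))     # made-up image-to-source mapping matrix

# Direct approach: transform the mapping matrix to the uv-plane on every likelihood call.
phase = 2.0 * np.pi * uv @ grid.T                   # shape (n_vis, n_image)
transformed = np.exp(-1j * phase) @ mapping         # shape (n_vis, n_source)
weighted = transformed / noise_sigma[:, None] ** 2
curvature_direct = np.real(transformed.conj().T @ weighted)

# Preload approach: w_tilde[i, j] = sum_k cos(2 pi u_k . (x_i - x_j)) / sigma_k^2, computed once.
delta = grid[:, None, :] - grid[None, :, :]         # pairwise pixel separations, (n_image, n_image, 2)
w_tilde = np.einsum("k,ijk->ij", 1.0 / noise_sigma**2, np.cos(2.0 * np.pi * delta @ uv.T))

# Each likelihood call now only needs real-space matrix products.
curvature_w_tilde = mapping.T @ w_tilde @ mapping

print(np.allclose(curvature_direct, curvature_w_tilde))
```

The expensive part, building `w_tilde`, scales with the number of visibilities but is done once, which is why moving it to JAX and GPU matters so much for large ALMA datasets.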

It also includes utilities for safely saving and loading precomputed data, which check metadata to ensure the loaded data matches the data being analysed.
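A minimal sketch of how such a metadata check might work, assuming a hypothetical `.npz` layout and helper names (`dataset_fingerprint`, `save_preload`, `load_preload`) that are not the PyAutoArray utilities: a hash of the inputs the preload depends on is stored next to the array, and loading refuses a file whose hash does not match the current dataset.

```python
import hashlib
import json

import numpy as np


def dataset_fingerprint(uv_wavelengths, real_space_mask):
    """Hash the inputs the preload depends on, so a stale file can be detected."""
    digest = hashlib.sha256()
    digest.update(np.ascontiguousarray(uv_wavelengths, dtype=np.float64).tobytes())
    digest.update(np.ascontiguousarray(real_space_mask, dtype=bool).tobytes())
    digest.update(str(np.shape(real_space_mask)).encode())
    return digest.hexdigest()


def save_preload(path, preload, fingerprint):
    # Store the array and its metadata together in a single .npz file.
    np.savez(path, preload=preload, metadata=json.dumps({"fingerprint": fingerprint}))


def load_preload(path, expected_fingerprint):
    with np.load(path) as data:
        metadata = json.loads(data["metadata"].item())
        if metadata["fingerprint"] != expected_fingerprint:
            raise ValueError("Preload was computed for a different dataset; recompute it.")
        return data["preload"]
```

Hashing the raw inputs rather than, say, a filename means a preload is only reused when the uv-coverage and real-space mask are genuinely identical.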

Shapelets

Jammy2211/PyAutoGalaxy#259

https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/imaging/features/advanced/shapelets

Full JAX support for elliptical polar and Cartesian shapelets, with the elliptical power shapelet the recommended default!
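For reference, the Cartesian shapelet basis is built from Hermite polynomials multiplied by a Gaussian. The pure-Python sketch below implements the standard orthonormal basis with scale `beta`; the helper names are hypothetical, the (y, x) argument order is an assumption, and the elliptical coordinate transform and polar shapelets supported by PyAutoGalaxy are not shown.

```python
import math


def hermite(n, x):
    """Physicists' Hermite polynomial H_n(x) via H_{k+1} = 2x H_k - 2k H_{k-1}."""
    h_prev, h = 1.0, 2.0 * x
    if n == 0:
        return h_prev
    for k in range(1, n):
        h_prev, h = h, 2.0 * x * h - 2.0 * k * h_prev
    return h


def shapelet_1d(n, x, beta):
    """Dimensional 1D shapelet basis function, orthonormal over the real line."""
    norm = 1.0 / math.sqrt(2.0**n * math.sqrt(math.pi) * math.factorial(n) * beta)
    u = x / beta
    return norm * hermite(n, u) * math.exp(-0.5 * u * u)


def shapelet_cartesian(n1, n2, y, x, beta):
    """2D Cartesian shapelet: a separable product of two 1D basis functions."""
    return shapelet_1d(n1, x, beta) * shapelet_1d(n2, y, beta)


print(shapelet_cartesian(1, 2, 0.3, -0.2, beta=0.5))
```

Orthonormality is what makes shapelets convenient as a linear basis: the source is a weighted sum of basis functions whose weights can be solved for linearly.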

JAX Interferometry

21 Dec 19:51

Efficient and scalable implementation of pixelized source reconstructions for interferometer analysis using JAX.

This PR introduces a new approach to source reconstructions for interferometer data that fully exploits the symmetries and sparsity of the non-uniform fast Fourier transform (NUFFT).

A high level summary of the implementation is:

  • Pixelized source reconstructions are performed such that the run time and VRAM usage are independent of the number of visibilities.

  • Lens modeling run times are fast, with a 1+ million visibility ALMA dataset being modeled in around 1 hour on a GPU!

  • Other improvements to interferometer analysis and a significant portion of support documentation and examples on the autolens_workspace are now provided.

Whilst a quantitative comparison has not yet been performed, my intuition is that this code runs significantly faster than the previous PyAutoLens interferometer modeling and the Powell et al. implementation.

Check out the interferometer package of the autolens_workspace for a complete run through of how to use JAX GPU interferometer analysis!

https://github.com/Jammy2211/autolens_workspace/tree/release/notebooks/interferometer

Delaunay JAX

15 Dec 10:10

The adaptive Delaunay mesh using a Hilbert image-mesh now supports fully JAX'd likelihood functions running on GPU, which were disabled in previous releases.

The Delaunay mesh itself is computed on the CPU rather than the GPU, via a JAX pure_callback. This does not significantly impact performance; full details are provided in the following pull request:

Jammy2211/PyAutoArray#199

JAX improvements + Fast CPU Pixelizations support + Delaunay

29 Nov 15:23

This release continues to improve the stability of JAX + GPU support:

https://github.com/Jammy2211/PyAutoLens/releases/tag/2025.11.18.1

This includes many small bug fixes; for example, more light and mass profiles now support JAX after small fixes to the source code.

Fast CPU Pixelizations

https://github.com/Jammy2211/autolens_workspace/blob/release/notebooks/imaging/features/pixelization/cpu_fast_modeling.ipynb

Before this release, pixelized source reconstructions could only be computed using JAX, either via GPU or CPU.

There were two important factors limiting run time:

1. GPU VRAM Limitations

JAX only provides significant acceleration on GPUs with large VRAM (≥16 GB).
To avoid excessive VRAM usage, examples often restrict pixelization meshes (e.g. 20 × 20).
On consumer GPUs with limited memory, JAX may be slower than CPU execution.

2. Sparse Matrix Performance

Pixelized source reconstructions require operations on very large, highly sparse matrices.

  • JAX currently lacks sparse-matrix support and must compute using dense matrices, which scale poorly.

This release restores support for PyAutoLens’s previous CPU implementation (via numba), which fully exploits sparsity, providing large speed gains at high image resolution (e.g. pixel_scales <= 0.03).

CPU execution can outperform JAX, even on powerful GPUs, for high-resolution datasets or when many CPU cores are used.
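To illustrate why sparsity matters, here is a hypothetical sketch (not the numba code in PyAutoArray) that builds the curvature matrix F = M^T N^-1 M from a mapping matrix M stored row by row as (source_index, weight) pairs. It assumes each image pixel maps to 3 source pixels, as with interpolation over a triangle, so the sparse loop does about n_image x 9 operations, whereas the dense product does about n_image x n_source^2.

```python
import random

import numpy as np


def sparse_curvature_matrix(mapping_rows, n_source, noise_sigma):
    """F = M^T N^-1 M, where mapping_rows[i] lists (source_index, weight) pairs for image pixel i."""
    curvature = np.zeros((n_source, n_source))
    for row, sigma in zip(mapping_rows, noise_sigma):
        # Only the few source pixels this image pixel maps to contribute.
        for a, w_a in row:
            for b, w_b in row:
                curvature[a, b] += w_a * w_b / sigma**2
    return curvature


random.seed(0)
n_image, n_source = 2000, 400
mapping_rows = []
for _ in range(n_image):
    indices = random.sample(range(n_source), 3)
    weights = [random.random() for _ in indices]
    total = sum(weights)
    mapping_rows.append([(j, w / total) for j, w in zip(indices, weights)])
noise_sigma = [1.0] * n_image

# Dense equivalent: almost every entry of M is zero but still stored and multiplied.
dense = np.zeros((n_image, n_source))
for i, row in enumerate(mapping_rows):
    for j, w in row:
        dense[i, j] = w
curvature_dense = dense.T @ (dense / np.asarray(noise_sigma)[:, None] ** 2)
curvature_sparse = sparse_curvature_matrix(mapping_rows, n_source, noise_sigma)

print(np.allclose(curvature_sparse, curvature_dense))
print(np.count_nonzero(dense) / dense.size)  # 3 non-zeros per 400-entry row
```

The gap widens as resolution increases: more image pixels and more source pixels make the dense matrix grow quadratically while the sparse work grows only linearly with the number of image pixels.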

Development is actively exploring how to get better performance from JAX by exploiting sparsity on GPU, but this is proving to be a very challenging problem.

Delaunay

https://github.com/Jammy2211/autolens_workspace/blob/release/notebooks/imaging/features/pixelization/delaunay.ipynb

Support for the Delaunay mesh, previously the main pixelized source reconstruction, has been restored in this release, although it currently only works with the numba implementation above and therefore only supports CPU.

Development is actively working on JAX support for Delaunay source reconstructions, which is expected to be available in the short term.

PyAutoLens JAX GPU Stability

18 Nov 15:38

PyAutoLens JAX Stability

The source code no longer imports or uses JAX unless instructed to by the user, meaning all calculations use regular numpy by default.

JAX is imported and used by Analysis objects when lens modeling begins, ensuring that fast lens modeling using GPUs is always performed by default.

Future versions of PyAutoLens will build on this design: to perform more general lensing calculations with JAX, users will apply JAX jitting and computation themselves. The docs and guides illustrating this are not written yet, but normal numpy run times are acceptable for most use cases.

Workspace Restructure

The workspace has been restructured such that the core packages are now the dataset types (imaging, interferometer, etc.):

https://github.com/Jammy2211/autolens_workspace

GPU Modeling Examples

The following Jupyter Notebooks, which run via Google Colab, illustrate < 10 minute lens modeling for different science cases:

PyAutoLens JAX Stability Pull Requests

These are described fully in the following two PRs:

#371

Large refactor which passes the numpy or jax numpy import through the code as xp.

This means that no JAX arrays are created inside the source code by default, with all calculations defaulting to numpy, giving the following benefits:

- Unit tests and general code use run faster, as JAX overheads are removed.
- Numba support for efficient CPU use can easily be retained, as numpy and JAX arrays are never mixed.
- There is less ambiguity in sections of code which don't play nicely with JAX arrays (e.g. visualization).
- It will allow for an easier, more explicit user interface, where users JAX-jit functions themselves and pass the namespace to get fast run times.
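The `xp` pattern can be sketched with a toy function: the array module is passed in as an argument, defaulting to numpy, so no JAX import happens unless the caller supplies `jax.numpy`. The function `gaussian_convergence` is hypothetical and not part of the PyAutoLens API; the JAX lines are left as comments so the example runs with numpy alone.

```python
import numpy as np


def gaussian_convergence(grid, centre=(0.0, 0.0), sigma=1.0, xp=np):
    """Toy profile whose array module is injected as `xp` (numpy by default)."""
    dy = grid[:, 0] - centre[0]
    dx = grid[:, 1] - centre[1]
    return xp.exp(-0.5 * (dy**2 + dx**2) / sigma**2)


grid = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
print(gaussian_convergence(grid))  # plain numpy: no JAX import or overhead

# With JAX installed, the user could jit the same function and pass the JAX namespace:
# import jax
# import jax.numpy as jnp
# fast_convergence = jax.jit(lambda g: gaussian_convergence(g, xp=jnp))
```

Passing the namespace explicitly keeps the numpy path free of JAX and makes it obvious at the call site which backend is in use.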

#372

A recent PR on the child projects made JAX optional for likelihood functions, whereby users pass the JAX namespace as the variable `xp` through the source code.

This PR makes JAX optional at the highest level (e.g. `PyAutoConf` and `PyAutoFit`), including:

- For a non-linear search to use JAX, the `use_jax` input must be passed as `True` to the `Analysis` object.
- The non-linear search will internally work out whether it supports JAX natively. Ultimately, for example, if gradients are used it will use `jax.grad`, if not it will use `jax.jit`, and if batching is supported it will use `jax.vmap`.
- Currently only `Nautilus` uses the `Analysis.use_jax` attribute to set up a `jax.vmap`.

There are a few hacky, unclean parts of the autofit model composition where it determines whether to use JAX based on input type. A more thorough consideration of how JAX should work in autofit will be performed in the future.

PyAutoLens-JAX GPU

20 Oct 18:27

UPDATE: Latest JAX version is now 2025.11.5.1

This release marks the completion of two years of work implementing JAX (https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html) in PyAutoLens.

With JAX, any lens modeling analysis can be run on GPU, with speed ups of ~x50 or more for all lens modeling.

Core Release

The core PyAutoLens API does not change significantly; however, existing users should redownload the new autolens workspace, which has new configs and examples:

https://github.com/Jammy2211/autolens_workspace

New users should check out the start_here.ipynb notebook, which can be opened in Google Colab by clicking the hyperlink.

GPU Modeling Examples

The following Jupyter Notebooks, which run via Google Colab, illustrate < 10 minute lens modeling for different science cases:

Performance Of Other Features

  • Pixelized sources run ~x5 - x20 faster on modern HPC GPU clusters, with lens modeling times typically ~10 - 20 minutes. Pixelized source performance depends on the available GPU VRAM. In November 2025 a release will make GPU performance of pixelized sources for all GPU hardware approach < 10 minute lens models.

  • Interferometer with many visibilities: above ~100,000 visibilities, interferometer performance suffers a significant slow down. In December 2025 a new release will make all interferometer modeling efficient irrespective of the number of visibilities.

  • CPU Performance: For pixelized sources CPU performance is worse than the previous PyAutoLens, as JAX is not optimized for CPUs. A future release will restore performance to be on par with previous versions, but users seeking to perform pixelized source modeling without GPU may wish to use the previous PyAutoLens.

Strong Lens Galaxy Clusters

This release can perform strong lens cluster calculations and lens modeling on GPU. For those familiar with cluster lensing, this includes performing an image-plane chi-squared multiple image calculation for clusters with hundreds of cluster members, with full support for multi-plane ray tracing of the entire cluster!
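For readers new to the term, an image-plane chi-squared compares observed multiple-image positions with the model's predicted image positions. The sketch below pairs each observed image with its nearest model image, a simple pairing rule assumed here for illustration; the helper is hypothetical, and how PyAutoLens finds model images (by solving the lens equation, including multi-plane ray tracing) and pairs them is not shown.

```python
import math


def image_plane_chi_squared(observed, model, noise):
    """Image-plane chi-squared for one multiply imaged point source.

    Each observed (y, x) image position is paired with its nearest model image position,
    and the squared separations divided by the positional noise squared are summed.
    """
    chi_squared = 0.0
    for y_obs, x_obs in observed:
        separation = min(math.hypot(y_obs - y_mod, x_obs - x_mod) for y_mod, x_mod in model)
        chi_squared += (separation / noise) ** 2
    return chi_squared


observed_positions = [(1.0, 0.0), (-1.0, 0.1)]
model_positions = [(1.02, 0.0), (-0.98, 0.1), (0.0, 0.5)]
print(image_plane_chi_squared(observed_positions, model_positions, noise=0.05))
```

In a cluster, the total chi-squared would sum this term over every multiply imaged source, which is why evaluating it quickly on GPU matters.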

Initial profiling shows it runs 50 or more times faster than other strong lens cluster codes run on CPU. Documentation and examples for cluster modeling are actively being developed but not yet mature. You can find the most up to date examples at the following links:

https://github.com/Jammy2211/autolens_workspace/blob/release/start_here_cluster.ipynb
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/simulators/cluster
https://github.com/Jammy2211/autolens_workspace/tree/release/scripts/modeling/cluster

May 2025

07 May 20:44

  • Results workflow API, which generates .csv, .png and .fits files of large libraries of results for quick and efficient inspection:

https://github.com/Jammy2211/autolens_workspace/tree/main/notebooks/results/workflow

  • Visualization now outputs .fits files corresponding to each subplot, which more concisely contain all information of a fit and are used by the above workflow API.

  • Visualization simplified, removing customization of individual image outputs.

  • Removed the Analysis summing API, replacing all dataset combinations with the AnalysisFactor and FactorGraphModel API used for graphical modeling:

https://github.com/Jammy2211/autolens_workspace/blob/main/notebooks/advanced/multi/modeling/start_here.ipynb

  • Pixelized source reconstructions are output as a .csv file, which can be loaded and interpolated for better source science analysis.

  • Double source plane lens modeling now outputs individual subplot_fit for each plane.

  • Latent variable API bug fixes; the API is now used in some test example scripts.
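As an example of loading and interpolating a reconstruction exported as CSV, the sketch below reads source-pixel centres and fluxes and interpolates them with inverse-distance weighting. The column names `y`, `x` and `flux` are hypothetical, since the exact layout of the PyAutoLens .csv output is not described here, and inverse-distance weighting is a swapped-in technique chosen because it needs no extra libraries.

```python
import csv
import io
import math


def load_reconstruction(csv_file):
    """Read (y, x, flux) rows from a CSV file object; the column names are hypothetical."""
    return [(float(r["y"]), float(r["x"]), float(r["flux"])) for r in csv.DictReader(csv_file)]


def interpolate_idw(pixels, y, x, power=2.0):
    """Inverse-distance-weighted interpolation of an irregular source reconstruction."""
    numerator = denominator = 0.0
    for py, px, flux in pixels:
        distance = math.hypot(y - py, x - px)
        if distance == 0.0:
            return flux  # exactly on a source-pixel centre
        weight = distance ** -power
        numerator += weight * flux
        denominator += weight
    return numerator / denominator


# An in-memory stand-in for a file; open("path/to/file.csv", newline="") works the same way.
csv_text = "y,x,flux\n0.0,0.0,1.0\n0.0,1.0,3.0\n"
pixels = load_reconstruction(io.StringIO(csv_text))
print(interpolate_idw(pixels, 0.0, 0.5))
```

For smoother results on an irregular mesh, a triangulation-based interpolator would be a natural alternative.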

January 2025

18 Jan 12:46

The main updates are visualization of Delaunay meshes using Delaunay triangles and a significant refactoring of over sampling, with the primary motivation of making the code much less complex for the ongoing JAX implementation.
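Over sampling evaluates a light profile at several sub-pixel positions within each image pixel and averages them, so compact, steep profiles are integrated more accurately than by evaluating the pixel centre alone. Below is a minimal sketch with a uniform sub-grid, assuming a hypothetical `over_sampled_pixel` helper; the refactored PyAutoLens scheme is not reproduced here and may choose sub-grid sizes differently. The exact pixel average of a circular Gaussian, computed with error functions, is included for comparison.

```python
import math


def gaussian(y, x, sigma=0.05):
    return math.exp(-0.5 * (y * y + x * x) / sigma**2)


def over_sampled_pixel(profile, y_centre, x_centre, pixel_scale, sub_size):
    """Average a light profile over a sub_size x sub_size grid of sub-pixel centres."""
    sub_scale = pixel_scale / sub_size
    total = 0.0
    for iy in range(sub_size):
        for ix in range(sub_size):
            y = y_centre - 0.5 * pixel_scale + (iy + 0.5) * sub_scale
            x = x_centre - 0.5 * pixel_scale + (ix + 0.5) * sub_scale
            total += profile(y, x)
    return total / sub_size**2


def exact_gaussian_pixel_average(y_centre, x_centre, pixel_scale, sigma=0.05):
    """Exact mean of the Gaussian above over a square pixel (it is separable in y and x)."""
    def axis_mean(c):
        lo, hi = c - 0.5 * pixel_scale, c + 0.5 * pixel_scale
        scale = sigma * math.sqrt(2.0)
        return sigma * math.sqrt(math.pi / 2.0) * (math.erf(hi / scale) - math.erf(lo / scale)) / pixel_scale
    return axis_mean(y_centre) * axis_mean(x_centre)


exact = exact_gaussian_pixel_average(0.0, 0.0, 0.1)
for sub_size in (1, 4, 16):
    print(sub_size, abs(over_sampled_pixel(gaussian, 0.0, 0.0, 0.1, sub_size) - exact))
```

The error shrinks as `sub_size` grows, at the cost of more profile evaluations per pixel, which is the trade-off any over sampling scheme has to balance.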

There have also been more improvements to point source modeling, including JAX functionality, which will be documented fully in the near future.

What's Changed

Full Changelog: 2024.11.13.2...2025.1.18.7

November 2024 update

13 Nov 14:00

Small bug fixes and optimizations for Euclid lens modeling pipeline.