added 23 commits
December 2, 2025 18:57
added 29 commits
December 10, 2025 17:49
Wraps `scipy.spatial.Delaunay` in a JAX primitive so that, under `jax.jit` or `jax.vmap`, the triangulation runs on the CPU whilst the rest of the likelihood function is computed efficiently on the GPU.
Many follow-on changes to the source code are included so that other aspects of the Delaunay mesh (e.g. `AdaptiveBrightnessSplit` regularization) work with JAX.
Benchmarking:
Using the Delaunay mesh via a pure_callback, the likelihood evaluation on the HPC GPU is currently ~2× slower than the rectangular mesh for both HST and Euclid datasets.
The dominant additional cost is ~0.025 s per likelihood evaluation, which occurs inside the callback and includes:
- construction of the Delaunay triangulation,
- locating simplices for the oversampled source-plane grid,
- construction of the Voronoi mesh and computation of cell areas for regularization.
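The three steps listed above can be sketched with SciPy alone. This is a minimal illustration, not the code in this PR: the function name `delaunay_geometry`, its return values, and the choice to mark unbounded edge cells with NaN are all assumptions made for the example.

```python
import numpy as np
from scipy.spatial import ConvexHull, Delaunay, Voronoi


def delaunay_geometry(mesh_points, query_points):
    """CPU-side geometry for a 2D mesh (hypothetical helper, for illustration)."""
    # Step 1: Delaunay triangulation of the mesh points.
    tri = Delaunay(mesh_points)

    # Step 2: index of the triangle containing each query point (-1 if outside).
    simplex_index = tri.find_simplex(query_points)

    # Step 3: Voronoi cells and their areas.
    vor = Voronoi(mesh_points)
    areas = np.full(len(mesh_points), np.nan)  # NaN marks unbounded edge cells
    for i, region_index in enumerate(vor.point_region):
        region = vor.regions[region_index]
        if len(region) >= 3 and -1 not in region:
            # Voronoi cells are convex, so the 2D hull "volume" is the cell area.
            areas[i] = ConvexHull(vor.vertices[region]).volume

    return tri.simplices, simplex_index, areas
```

How edge cells should be treated for regularization (here they are simply left as NaN) is a design choice the sketch does not try to settle.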
All of this work is performed on the CPU, and given that, the overhead is relatively modest: it increases the total runtime by roughly a factor of two.
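For readers unfamiliar with the mechanism, a minimal sketch of calling CPU-side SciPy code from a jitted function via `jax.pure_callback` is shown here. It is not the implementation in this PR: the names `_find_simplex_host` and `locate_simplices` are hypothetical, and only the simplex lookup is returned because the callback's output shape must be known ahead of time (the number of triangles is not, whereas there is exactly one index per query point).

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import Delaunay


def _find_simplex_host(mesh_points, query_points):
    # Runs on the host (CPU); inputs are converted to NumPy arrays first.
    tri = Delaunay(np.asarray(mesh_points))
    return tri.find_simplex(np.asarray(query_points)).astype(np.int32)


@jax.jit
def locate_simplices(mesh_points, query_points):
    # Declare the output shape and dtype up front: one int32 per query point.
    out_spec = jax.ShapeDtypeStruct((query_points.shape[0],), jnp.int32)
    return jax.pure_callback(_find_simplex_host, out_spec, mesh_points, query_points)
```

Batching under `jax.vmap` needs extra care (recent JAX versions expose a `vmap_method` argument on `jax.pure_callback` for this), which the sketch leaves out.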
For both rectangular and Delaunay meshes, there is already an irreducible cost of ~0.025 s per sample from non-sparse linear algebra. The extra ~0.025 s of CPU-side geometry work for the Delaunay mesh is therefore comparable to the existing cost rather than much larger than it, which is why the slowdown is limited to ~2×.
If and when JAX/GPU linear algebra fully exploits sparsity, the rectangular mesh is expected to become substantially faster (potentially an order of magnitude or more), at which point the Delaunay approach may become comparatively less attractive from a performance standpoint.
There are early indications that the Delaunay mesh may converge in fewer iterations, meaning the effective runtime penalty can be smaller than the raw ~2× per-sample cost, though this needs to be tested on more complex models.
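To make the trade-off concrete, the effective runtime ratio is the per-evaluation cost ratio multiplied by the fraction of iterations the Delaunay mesh needs relative to the rectangular mesh. The sketch below uses the ~0.025 s figures quoted above; the iteration fraction of 0.75 is a purely hypothetical value chosen for illustration, not a measured result.

```python
def runtime_ratio(delaunay_per_eval, rect_per_eval, iteration_fraction):
    """Total Delaunay runtime divided by total rectangular runtime."""
    return (delaunay_per_eval / rect_per_eval) * iteration_fraction


rect_per_eval = 0.025                      # s, linear algebra only
delaunay_per_eval = rect_per_eval + 0.025  # s, plus CPU-side geometry

# Same number of iterations: the raw ~2x per-sample penalty.
print(runtime_ratio(delaunay_per_eval, rect_per_eval, 1.0))   # 2.0

# Hypothetical: Delaunay converges in 75% of the iterations.
print(runtime_ratio(delaunay_per_eval, rect_per_eval, 0.75))  # 1.5
```

Under these numbers the two meshes would break even only if the Delaunay mesh converged in half as many iterations.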
Going forward, I plan to maintain and science-test both the Delaunay and rectangular meshes in parallel. The longer-term goal is to build sufficient confidence in the rectangular implementation to potentially adopt it as the default.