Skip to content

Add JAX JIT Caching for TN Contraction#41

Draft
s-mandra wants to merge 1 commit intomainfrom
jax-fix
Draft

Add JAX JIT Caching for TN Contraction#41
s-mandra wants to merge 1 commit intomainfrom
jax-fix

Conversation

@s-mandra
Copy link
Collaborator

Enhances the contract function to support JAX's Just-In-Time (JIT) compilation when the 'jax' backend is used.

Key changes:

  • Refactored the core contraction logic into _contraction_core to enable JIT compatibility.
  • Implemented a caching mechanism (__JAX_CACHE__) for JIT-compiled contraction cores to improve performance on repeated structures.
  • Added index re-mapping to integers within contract to ensure compatibility with JAX JIT requirements.
  • Updated type hints and return logic to handle cases with and without explicit tensor arrays.

Enhances the `contract` function to support JAX's Just-In-Time (JIT)
compilation when the 'jax' backend is used.

Key changes:
- Refactored the core contraction logic into `_contraction_core` to
  enable JIT compatibility.
- Implemented a caching mechanism (`__JAX_CACHE__`) for JIT-compiled
  contraction cores to improve performance on repeated structures.
- Added index re-mapping to integers within `contract` to ensure
  compatibility with JAX JIT requirements.
- Updated type hints and return logic to handle cases with and
  without explicit tensor arrays.
@s-mandra s-mandra self-assigned this Feb 16, 2026
@s-mandra s-mandra added the enhancement New feature or request label Feb 16, 2026
@s-mandra s-mandra linked an issue Feb 16, 2026 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Sampling is Slow When contraction_backend='jax'

1 participant