Skip to content

Conversation

@fromseto
Copy link
Owner

This commit builds upon the JAX integration by adding JAX-based Hessian computation for both MFAModel and InstMFAModel. It also includes the correction to Fitter.solve() from the previous iteration.

Key changes:

  • MFAModel and InstMFAModel now compute the Hessian using jax.hessian on their respective core JAX objective functions when the JAX pathway is active. These are JIT-compiled.
  • The JAX-computed Hessian is provided to SciPy optimizers if the solver is not 'slsqp' or 'ralg'.
  • Corrected Fitter.solve() in fit.py to properly pass JAX usage flags.

Existing functionality:

  • JAX-based objective and gradient for MFAModel (steady-state).
  • Structural JAX support for InstMFAModel (instationary), with the core instationary JAX calculation (core_calculate_inst_mdvs_jax) still being a non-functional placeholder.
  • Data preparation in Calculator and Fitter/InstFitter for JAX pathways.

Testing:

  • Test script execution remains blocked by environment issues. Full validation of JAX steady-state path (including Hessian) and instationary path structure could not be completed.

Further work:

  • Implement core_calculate_inst_mdvs_jax for instationary models.
  • Thoroughly test and benchmark all JAX pathways.

This commit builds upon the JAX integration by adding JAX-based Hessian
computation for both MFAModel and InstMFAModel. It also includes the
correction to Fitter.solve() from the previous iteration.

Key changes:
- MFAModel and InstMFAModel now compute the Hessian using `jax.hessian`
  on their respective core JAX objective functions when the JAX pathway
  is active. These are JIT-compiled.
- The JAX-computed Hessian is provided to SciPy optimizers if the solver
  is not 'slsqp' or 'ralg'.
- Corrected Fitter.solve() in fit.py to properly pass JAX usage flags.

Existing functionality:
- JAX-based objective and gradient for MFAModel (steady-state).
- Structural JAX support for InstMFAModel (instationary), with the
  core instationary JAX calculation (`core_calculate_inst_mdvs_jax`)
  still being a non-functional placeholder.
- Data preparation in Calculator and Fitter/InstFitter for JAX pathways.

Testing:
- Test script execution remains blocked by environment issues.
  Full validation of JAX steady-state path (including Hessian) and
  instationary path structure could not be completed.

Further work:
- Implement `core_calculate_inst_mdvs_jax` for instationary models.
- Thoroughly test and benchmark all JAX pathways.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants