PEtab SciML test suite - test_net and (some) test_ude cases #2947

Merged
BSnelling merged 20 commits into jax_sciml from bes/jax_sciml
Oct 14, 2025
Conversation

@BSnelling
Collaborator

Implements support for the PEtab SciML test suite: https://github.com/sebapersson/petab_sciml_testsuite/tree/main

Includes support for all test_net cases and a subset of test_ude. Test cases with frozen layers and networks in the observable formulae are not yet implemented.

@BSnelling BSnelling requested a review from a team as a code owner September 2, 2025 13:33
@BSnelling BSnelling marked this pull request as draft September 2, 2025 13:39
@codecov

codecov bot commented Sep 2, 2025

Codecov Report

❌ Patch coverage is 99.01961% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 17.87%. Comparing base (f1ece15) to head (23227aa).

Files with missing lines Patch % Lines
python/sdist/amici/jax/petab.py 97.95% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (f1ece15) and HEAD (23227aa).

HEAD has 4 uploads less than BASE
Flag BASE (f1ece15) HEAD (23227aa)
python 3 0
cpp_python 1 0

@@              Coverage Diff               @@
##           jax_sciml    #2947       +/-   ##
==============================================
- Coverage      41.44%   17.87%   -23.57%     
==============================================
  Files            303      104      -199     
  Lines          19945    16072     -3873     
  Branches        1501     1412       -89     
==============================================
- Hits            8266     2873     -5393     
- Misses         11654    13199     +1545     
+ Partials          25        0       -25     
Flag Coverage Δ
cpp_python ?
petab_sciml 14.01% <98.03%> (?)
python ?
sbmlsuite-jax 33.91% <14.70%> (+1.55%) ⬆️

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
python/sdist/amici/de_model.py 55.65% <100.00%> (-27.57%) ⬇️
python/sdist/amici/jax/jaxcodeprinter.py 90.32% <100.00%> (+11.01%) ⬆️
python/sdist/amici/jax/model.py 68.00% <100.00%> (+3.46%) ⬆️
python/sdist/amici/jax/nn.py 97.61% <100.00%> (+83.92%) ⬆️
python/sdist/amici/jax/ode_export.py 88.33% <ø> (+5.00%) ⬆️
python/sdist/amici/petab/parameter_mapping.py 44.81% <100.00%> (-19.23%) ⬇️
python/sdist/amici/petab/petab_import.py 92.30% <100.00%> (-7.70%) ⬇️
python/sdist/amici/jax/petab.py 73.73% <97.95%> (+55.27%) ⬆️

... and 268 files with indirect coverage changes

][:],
dtype=jnp.float64,
),
) # ?? hardcoded dtype not ideal ?? could infer from env somehow ??
Member
is it really necessary to set dtype here? Usually jax infers float precision from https://docs.jax.dev/en/latest/config_options.html. Might be necessary to cast this as numpy array first if conversion from hdf5 is the problem

Collaborator Author
I tried casting to a numpy array but the array persisted as float32s. I've defined the dtype based on the current jax config settings which I think is better than hard coding.
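
A minimal sketch of that idea (hypothetical helper name `default_float_dtype`, not from the PR): derive the dtype from the active JAX x64 setting instead of hardcoding float64.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def default_float_dtype():
    # Mirror the current JAX precision config rather than hardcoding:
    # jax_enable_x64 is True when 64-bit mode is active.
    return jnp.float64 if jax.config.jax_enable_x64 else jnp.float32


# e.g. when converting an array loaded from HDF5
arr = jnp.asarray([1.0, 2.0], dtype=default_float_dtype())
```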

petab.NOMINAL_VALUE,
],
)
if "input"
Member
this check seems a bit too unspecific. I think you want to construct a sequence of PEtab ids that are mapped to $nnId.inputs{[$inputArgumentIndex]{[$inputIndex]}}?

Collaborator Author
I've updated this. The complication is that the values here could be pulled from a nominal value in the parameter table or from a value in the conditions table, depending on whether the id appears in the parameters table. That's my understanding anyway.
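
The lookup order described here could be sketched roughly as follows (hypothetical helper; table layout assumes the PEtab convention of a `nominalValue` column in the parameter table and one column per value in the condition table):

```python
import pandas as pd


def resolve_input_value(input_id, parameter_df, condition_df, condition_id):
    """Resolve a network-input value: prefer the parameter table's
    nominalValue if the id is listed there, otherwise fall back to the
    condition table entry for the given condition."""
    if input_id in parameter_df.index:
        return parameter_df.loc[input_id, "nominalValue"]
    return condition_df.loc[condition_id, input_id]


# toy tables (made-up ids)
parameter_df = pd.DataFrame({"nominalValue": [1.5]}, index=["net1_input1"])
condition_df = pd.DataFrame({"net1_input2": [0.7]}, index=["c0"])

from_params = resolve_input_value("net1_input1", parameter_df, condition_df, "c0")
from_conditions = resolve_input_value("net1_input2", parameter_df, condition_df, "c0")
```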

dfs.append(df_sc)
return pd.concat(dfs).sort_index()

def apply_grad_filter(problem: JAXProblem):
Member
This is a great solution! Only thing that I am a bit worried about is that in the current implementation stop_gradient is only applied when calling problem methods in the context of run_simulations, which may lead to confusion when trying to compute gradients outside of that context. My interpretation of the petab problem definition is that setting estimate=0 means that gradient computation is permanently disabled and we should apply apply_grad_filter during JaxProblem instantiation.
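
A minimal sketch of the stop_gradient idea (hypothetical names; the actual `apply_grad_filter` operates on a `JAXProblem`): parameters flagged with estimate=0 are wrapped in `jax.lax.stop_gradient`, so their gradients come out as zero.

```python
import jax
import jax.numpy as jnp


def apply_grad_filter(params, estimate_flags):
    """Block gradient flow through parameters with estimate=0."""
    return {
        name: value
        if estimate_flags.get(name, True)
        else jax.lax.stop_gradient(value)
        for name, value in params.items()
    }


params = {"k1": jnp.array(2.0), "k2": jnp.array(3.0)}
flags = {"k1": True, "k2": False}  # k2 has estimate=0


def loss(p):
    p = apply_grad_filter(p, flags)
    return p["k1"] ** 2 + p["k2"] ** 2


grads = jax.grad(loss)(params)
# grads["k1"] is 4.0; grads["k2"] is 0.0 because of stop_gradient
```

Applying this once at instantiation, as suggested, makes the behavior consistent regardless of whether gradients are taken inside or outside run_simulations.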

@FFroehlich
Member

just checking test failures:

  • Notebook tests: also fail on the base branch, albeit for a different reason; not related to the changes here
  • mac os, this looks like an issue with PRs from a fork
  • doc tests: also problem in base branch, h5py is missing from doc requirements
  • sbml jax: unrelated, also failing in base branch

@FFroehlich
Member

  • mac os, this looks like an issue with PRs from a fork

this is probably not related to failures from forks, but rather to CMake: #2949 (review)

@FFroehlich
Member

just updated the base branch (hopefully without messing up any of the merge conflicts); this should fix the failing mac tests

@BSnelling BSnelling marked this pull request as ready for review October 14, 2025 12:40
if len(nn_input_arrays) > 0:
for net_id in model_pars:
for input in model.nns[net_id].inputs:
input_array = dict(
Member
bit unwieldy; could we simplify this?

self._parameter_mappings[sc].map_sim_var[p], sc
)
if p
in set(self._parameter_mappings[sc].map_sim_var.keys())
Member
This isn't quite the check I would expect here (rather something with condition_df/mapping_df), but fine to keep as is for now, as we will have to change the whole parameter mapping setup for v2 anyway.

@BSnelling BSnelling merged commit 2641a80 into jax_sciml Oct 14, 2025
17 of 18 checks passed