PEtab SciML test suite - test_net and (some) test_ude cases#2947
PEtab SciML test suite - test_net and (some) test_ude cases#2947
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## jax_sciml #2947 +/- ##
==============================================
- Coverage 41.44% 17.87% -23.57%
==============================================
Files 303 104 -199
Lines 19945 16072 -3873
Branches 1501 1412 -89
==============================================
- Hits 8266 2873 -5393
- Misses 11654 13199 +1545
+ Partials 25 0 -25
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
python/sdist/amici/jax/petab.py
Outdated
| ][:], | ||
| dtype=jnp.float64, | ||
| ), | ||
| ) # ?? hardcoded dtype not ideal ?? could infer from env somehow ?? |
There was a problem hiding this comment.
is it really necessary to set dtype here? Usually jax infers float precision from https://docs.jax.dev/en/latest/config_options.html. Might be necessary to cast this as numpy array first if conversion from hdf5 is the problem
There was a problem hiding this comment.
I tried casting to a numpy array but the array persisted as float32s. I've defined the dtype based on the current jax config settings which I think is better than hard coding.
python/sdist/amici/jax/petab.py
Outdated
| petab.NOMINAL_VALUE, | ||
| ], | ||
| ) | ||
| if "input" |
There was a problem hiding this comment.
this check seems a bit too unspecific. I think you want to construct a sequence of petab id's that are mapped to $nnId.inputs{[$inputArgumentIndex]{[$inputIndex]}}?
There was a problem hiding this comment.
I've updated this. The complication is that the values here could be pulled from a nominal value in the parameter table or from a value in the conditions table, depending on whether the id appears in the parameters table. That's my understanding anyway.
python/sdist/amici/jax/petab.py
Outdated
| dfs.append(df_sc) | ||
| return pd.concat(dfs).sort_index() | ||
|
|
||
| def apply_grad_filter(problem: JAXProblem,): |
There was a problem hiding this comment.
This is a great solution! Only thing that I am a bit worried about is that in the current implementation stop_gradient is only applied when calling problem methods in the context of run_simulations, which may lead to confusion when trying to compute gradients outside of that context. My interpretation of the petab problem definition is that setting estimate=0 means that gradient computation is permanently disabled and we should apply apply_grad_filter during JaxProblem instantiation.
|
just checking test failures:
|
this is probably not related to failures from forks, but rather CMAKE: #2949 (review) |
|
just updated the base branch (hopefully without messing up any of the merge conflicts), this should hopefully fix the failing mac tests |
- update sciml testsuite submodule to point at main - fix a eqx LayerNorm deprecation warning - fix string formatting of bool in kwarg - updates to test code driven by updated sciml format
…ludes: - frozen nn layers - nns in observable formulae
generalise frozen layers to networks across system Use stop_grad instead
78d2cb9 to
9e744c7
Compare
python/sdist/amici/jax/petab.py
Outdated
| if len(nn_input_arrays) > 0: | ||
| for net_id in model_pars: | ||
| for input in model.nns[net_id].inputs: | ||
| input_array = dict( |
There was a problem hiding this comment.
bit unwieldy could we simplify this?
| self._parameter_mappings[sc].map_sim_var[p], sc | ||
| ) | ||
| if p | ||
| in set(self._parameter_mappings[sc].map_sim_var.keys()) |
There was a problem hiding this comment.
This isn't quite the check I would expect here (rather something with condition_df/mapping_df), but fine to keep as is for now as we will anyways have to change the whole parameter mapping setup for v2.
Implements support for petab sciml test suite: https://github.com/sebapersson/petab_sciml_testsuite/tree/main
Includes support for all
test_netcases and a subset oftest_ude. Test cases with frozen layers and networks in the observable formulae are not yet implemented.