Merged
18 changes: 10 additions & 8 deletions .github/workflows/test_petab_sciml.yml
@@ -3,11 +3,12 @@ on:
push:
branches:
- develop
- master
- main
pull_request:
branches:
- master
- main
- develop
- jax_sciml
merge_group:
workflow_dispatch:

@@ -33,6 +34,7 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 20
submodules: recursive

- name: Install apt dependencies
uses: ./.github/actions/install-apt-dependencies
@@ -59,26 +61,26 @@
- name: Download and install PEtab SciML
run: |
source ./venv/bin/activate \
&& python -m pip install git+https://github.com/sebapersson/petab_sciml.git@unify_data#subdirectory=src/python \
&& python -m pip install git+https://github.com/sebapersson/petab_sciml.git@main#subdirectory=src/python \


- name: Install petab
run: |
source ./venv/bin/activate \
&& python3 -m pip uninstall -y petab \
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@develop \
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@sciml \

- name: Run PEtab SciML testsuite
run: |
source ./venv/bin/activate \
&& pytest --cov-report=xml:coverage.xml \
--cov=./ tests/sciml/test_sciml.py
&& pytest --cov-report=xml:coverage_petab_sciml.xml \
--cov=amici tests/sciml/test_sciml.py

- name: Codecov
if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: coverage.xml
flags: petab
file: coverage_petab_sciml.xml
flags: petab_sciml
fail_ci_if_error: true
5 changes: 5 additions & 0 deletions doc/conf.py
@@ -33,6 +33,7 @@
import amici
import pandas as pd # noqa: F401
import sympy as sp # noqa: F401
import warnings


def install_doxygen():
@@ -364,6 +365,10 @@ def install_doxygen():
"ExpDataPtrVector": ":class:`amici.amici.ExpData`",
}

# TODO: alias for forward type definition, remove after release of petab_sciml
autodoc_type_aliases = {
"NNModel": "petab_sciml.NNModel",
}

def process_docstring(app, what, name, obj, options, lines):
# only apply in the amici.amici module
41 changes: 29 additions & 12 deletions doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb
@@ -82,7 +82,9 @@
"cell_type": "markdown",
"id": "415962751301c64a",
"metadata": {},
"source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results."
"source": [
"This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results."
]
},
{
"cell_type": "code",
@@ -94,7 +96,9 @@
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Access the results for the specified condition\n",
"results[simulation_condition]"
"ic = results[\"simulation_conditions\"].index(simulation_condition)\n",
"print(\"llh: \", results[\"llh\"][ic])\n",
"print(\"state variables: \", results[\"x\"][ic, :])"
]
},
{
@@ -129,7 +133,9 @@
"cell_type": "markdown",
"id": "fe4d3b40ee3efdf2",
"metadata": {},
"source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories."
"source": [
"Success! The simulation completed successfully, and we can now plot the resulting state trajectories."
]
},
{
"cell_type": "code",
@@ -180,7 +186,9 @@
"cell_type": "markdown",
"id": "4fa97c33719c2277",
"metadata": {},
"source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all."
"source": [
"`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all."
]
},
{
"cell_type": "code",
@@ -283,7 +291,9 @@
"cell_type": "markdown",
"id": "dc9bc07cde00a926",
"metadata": {},
"source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`."
"source": [
"Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`."
]
},
{
"cell_type": "code",
@@ -302,7 +312,9 @@
"cell_type": "markdown",
"id": "851c3ec94cb5d086",
"metadata": {},
"source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly via `grad.parameters`."
"source": [
"Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly via `grad.parameters`."
]
},
{
"cell_type": "code",
@@ -318,7 +330,9 @@
"cell_type": "markdown",
"id": "375b835fecc5a022",
"metadata": {},
"source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`."
"source": [
"Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`."
]
},
{
"cell_type": "code",
@@ -334,22 +348,26 @@
"cell_type": "markdown",
"id": "8eb7cc3db510c826",
"metadata": {},
"source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out."
"source": [
"Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out."
]
},
{
"cell_type": "code",
"id": "3badd4402cf6b8c6",
"metadata": {},
"outputs": [],
"source": [
"grad._measurements[simulation_condition]"
"grad._my[ic, :]"
]
},
{
"cell_type": "markdown",
"id": "58eb04393a1463d",
"metadata": {},
"source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after solving the differential equation. While this might not be particularly practical, it serves as a nice illustration of the power of automatic differentiation."
"source": [
"However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after solving the differential equation. While this might not be particularly practical, it serves as a nice illustration of the power of automatic differentiation."
]
},
{
"cell_type": "code",
@@ -671,8 +689,7 @@
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
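The zeroed measurement gradients discussed in the notebook cells above can be reproduced in a few lines. The function and variable names below (`loss`, `my`) are illustrative, not part of the AMICI API: any value routed through `jax.lax.stop_gradient` is treated as a constant by autodiff and contributes nothing to the derivative.

```python
import jax
import jax.numpy as jnp


# Minimal sketch (hypothetical names, not AMICI API): the second term depends
# on `my` only through stop_gradient, so it adds nothing to the gradient.
def loss(my):
    return jnp.sum(my**2) + 3.0 * jnp.sum(jax.lax.stop_gradient(my))


my = jnp.array([1.0, 2.0])
print(jax.grad(loss)(my))  # gradient of the my**2 term only: [2. 4.]
```

Without the `stop_gradient`, the gradient would instead be `2 * my + 3`, which is why the measurement entries in `grad` appear as zeros rather than being absent.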
3 changes: 3 additions & 0 deletions doc/rtd_requirements.txt
@@ -7,6 +7,8 @@ setuptools>=67.7.2
# https://github.com/pysb/pysb/pull/599
# for building the documentation, we don't care whether this fully works
git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af
# For forward type definition in generate_equinox
git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python
matplotlib>=3.7.1
optax
nbsphinx
@@ -16,6 +18,7 @@ sphinx_rtd_theme>=1.2.0
petab[vis]>=0.2.0
sphinx-autodoc-typehints
ipython>=8.13.2
h5py>=3.14.0
breathe>=4.35.0
exhale>=0.3.7
-e git+https://github.com/mithro/sphinx-contrib-mithro#egg=sphinx-contrib-exhale-multiproject&subdirectory=sphinx-contrib-exhale-multiproject
45 changes: 42 additions & 3 deletions python/sdist/amici/de_model.py
@@ -2564,6 +2564,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
hybridization information
"""
added_expressions = False
orig_obs = tuple([s.get_id() for s in self._observables])
for net_id, net in hybridization.items():
if net["static"]:
continue # do not integrate into ODEs, handle in amici.jax.petab
@@ -2595,7 +2596,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
)

outputs = {
out_var: comp
out_var: {"comp": comp, "ind": net["output_vars"][out_var]}
for comp in self._components
if (out_var := str(comp.get_id())) in net["output_vars"]
# TODO: SYNTAX NEEDS to CHANGE
@@ -2606,7 +2607,9 @@
raise ValueError(
f"Could not find all output variables for neural network {net_id}"
)
for iout, (out_var, comp) in enumerate(outputs.items()):

for out_var, parts in outputs.items():
comp = parts["comp"]
# remove output from model components
if isinstance(comp, Parameter):
self._parameters.remove(comp)
@@ -2621,7 +2624,7 @@

# generate dummy Function
out_val = sp.Function(net_id)(
*[input.get_id() for input in inputs], iout
*[input.get_id() for input in inputs], parts["ind"]
)

# add to the model
@@ -2642,6 +2645,42 @@
)
added_expressions = True

observables = {
ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]}
for comp in self._components
if (ob_var := str(comp.get_id())) in net["observable_vars"]
# # TODO: SYNTAX NEEDS to CHANGE
# or (ob_var := str(comp.get_id()) + "_dot")
# in net["observable_vars"]
}
if len(observables.keys()) != len(net["observable_vars"]):
raise ValueError(
f"Could not find all observable variables for neural network {net_id}"
)

for ob_var, parts in observables.items():
comp = parts["comp"]
if isinstance(comp, Observable):
self._observables.remove(comp)
else:
raise ValueError(
f"{comp.get_name()} ({type(comp)}) is not an observable."
)
out_val = sp.Function(net_id)(
*[input.get_id() for input in inputs], parts["ind"]
)
# add to the model
self.add_component(
Observable(
identifier=comp.get_id(),
name=net_id,
value=out_val,
)
)

# restore the original observable order (removed NN observables were
# re-appended at the end); sort each observable back to its original index
self._observables = sorted(
self._observables, key=lambda s: orig_obs.index(s.get_id())
)

if added_expressions:
# toposort expressions
w_sorted = toposort_symbols(
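The order-restoration step at the end of `_process_hybridization` can be sketched in isolation. The identifier lists below are hypothetical stand-ins for what `s.get_id()` returns in the model: removing an observable and re-adding it moves it to the end of the list, and sorting by the original index puts it back.

```python
# Hypothetical observable identifiers; in de_model they come from s.get_id().
orig_obs = ("obs_a", "obs_b", "obs_c")  # order captured before processing
observables = ["obs_b", "obs_c", "obs_a"]  # "obs_a" was removed and re-appended

# sort each observable back to its original position
restored = sorted(observables, key=orig_obs.index)
print(restored)  # ['obs_a', 'obs_b', 'obs_c']
```

Sorting by the original index is robust no matter how many observables were moved, whereas permuting with a list of original indices only restores the order when the permutation happens to be its own inverse.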
8 changes: 6 additions & 2 deletions python/sdist/amici/jax/jaxcodeprinter.py
@@ -6,6 +6,7 @@

import sympy as sp
from sympy.printing.numpy import NumPyPrinter
from sympy.core.function import UndefinedFunction


def _jnp_array_str(array) -> str:
@@ -42,8 +43,11 @@ def _print_Mul(self, expr: sp.Expr) -> str:
return super()._print_Mul(expr)
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"

def _print_Function(self, expr):
return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]"
def _print_Function(self, expr: sp.Expr) -> str:
if isinstance(expr.func, UndefinedFunction):
return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]"
else:
return super()._print_Function(expr)

def _print_Max(self, expr: sp.Expr) -> str:
"""
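The dispatch added to `_print_Function` can be tried standalone: only undefined functions (the dummy functions generated for neural networks, whose last argument is the output index) are rendered as `forward` calls, while known functions such as `sin` fall back to the stock NumPy printer. The class and symbol names below are illustrative.

```python
import sympy as sp
from sympy.core.function import UndefinedFunction
from sympy.printing.numpy import NumPyPrinter


class NNPrinter(NumPyPrinter):
    """Illustrative subset of the printer: NN dummy functions vs. builtins."""

    def _print_Function(self, expr: sp.Expr) -> str:
        if isinstance(expr.func, UndefinedFunction):
            # all args but the last are NN inputs; the last is the output index
            args = ", ".join(self.doprint(a) for a in expr.args[:-1])
            return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{args}]))[{expr.args[-1]}]"
        return super()._print_Function(expr)


x, y = sp.symbols("x y")
printer = NNPrinter()
# dummy function as generated in de_model: output index 0 of network "net1"
print(printer.doprint(sp.Function("net1")(x, y, 0)))
# known functions are unaffected by the new branch
print(printer.doprint(sp.sin(x)))
```

Without the `isinstance` guard (the pre-PR behavior), every function call, including `sin`, would have been printed as a neural-network `forward` call.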
11 changes: 11 additions & 0 deletions python/sdist/amici/jax/model.py
@@ -545,6 +545,8 @@ def simulate_condition(
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]),
init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
@@ -588,6 +590,10 @@
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param init_override:
override for the initial state, e.g., with neural-network outputs. If not provided, the initial state is not overridden.
:param init_override_mask:
mask for the initial-state override. If `True`, the corresponding state entry is replaced with the value from `init_override`.
:param ts_mask:
mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of
the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
@@ -602,6 +608,11 @@

if x_preeq.shape[0]:
x = x_preeq
elif init_override.shape[0]:
x_def = self._x0(t0, p)
x = jnp.squeeze(
jnp.where(init_override_mask, init_override, x_def)
)
else:
x = self._x0(t0, p)

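The masked initial-state override added to `simulate_condition` boils down to a single element-wise `where`. The sketch below uses NumPy with made-up values to show the semantics: masked entries take the override value (e.g., a neural-network output), all others keep the model default.

```python
import numpy as np

# Hypothetical values: default initial state vs. NN-provided overrides.
x_def = np.array([1.0, 2.0, 3.0])  # what self._x0(t0, p) would return
init_override = np.array([9.0, 0.0, 0.0])  # e.g. neural-network output
init_override_mask = np.array([True, False, False])

# True entries are replaced, False entries keep the model default
x0 = np.where(init_override_mask, init_override, x_def)
print(x0)  # [9. 2. 3.]
```

In the actual code the same pattern is expressed with `jnp.where`, so the override remains differentiable with respect to the neural-network outputs.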