From 1701edeee1d90dec6885ffec862943d304418664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 28 Jul 2025 14:36:27 +0100 Subject: [PATCH 1/3] Update ExampleJaxPEtab.ipynb --- .../example_jax_petab/ExampleJaxPEtab.ipynb | 374 +++++++----------- 1 file changed, 148 insertions(+), 226 deletions(-) diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 2490754ec9..ed061a40b0 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -26,15 +26,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "c71c96da0da3144a", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:12.220374Z", - "start_time": "2025-01-29T15:49:12.114366Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "from amici.petab.petab_import import import_petab_problem\n", "import petab.v1 as petab\n", @@ -55,7 +48,9 @@ " verbose=False, # no text output\n", " jax=True, # return jax model\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -69,15 +64,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "ccecc9a29acc7b73", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.455948Z", - "start_time": "2025-01-29T15:49:12.224414Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "from amici.jax import JAXProblem, run_simulations\n", "\n", @@ -86,7 +74,9 @@ "\n", "# Run simulations and compute the log-likelihood\n", "llh, results = run_simulations(jax_problem)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -98,19 +88,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "596b86e45e18fe3d", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.469126Z", - "start_time": "2025-01-29T15:49:13.464492Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "# Access the results\n", "results" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -124,15 +109,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "f4f5ff705a3f7402", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.517447Z", - "start_time": "2025-01-29T15:49:13.498128Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "import jax\n", "\n", @@ -143,7 +121,9 @@ "llh, results = run_simulations(jax_problem)\n", "\n", "results" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -155,15 +135,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "72f1ed397105e14a", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.626555Z", - "start_time": "2025-01-29T15:49:13.540193Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", @@ -201,7 +174,9 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -213,19 +188,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "7950774a3e989042", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.640281Z", - "start_time": "2025-01-29T15:49:13.637222Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", "results" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -239,15 +209,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "3d278a3d21e709d", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.690093Z", - "start_time": "2025-01-29T15:49:13.666663Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "from dataclasses import FrozenInstanceError\n", "import jax\n", @@ -265,7 +228,9 @@ " jax_problem.parameters += noise\n", "except FrozenInstanceError as e:\n", " print(\"Error:\", e)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -279,15 +244,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "e47748376059628b", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.810758Z", - "start_time": "2025-01-29T15:49:13.712463Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "# Update the parameters and create a new JAXProblem instance\n", "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n", @@ -297,7 +255,9 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -306,27 +266,22 @@ "source": [ "## Computing Gradients\n", "\n", - "Similar to updating attributes, computing gradients in the JAX ecosystem can feel a bit unconventional if you’re not familiar with the JAX ecosysmt. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." + "Similar to updating attributes, computing gradients in the JAX ecosystem can feel a bit unconventional if you’re not familiar with the JAX ecosystem. JAX offers [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." ] }, { "cell_type": "code", - "execution_count": null, "id": "7033d09cc81b7f69", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:13.824702Z", - "start_time": "2025-01-29T15:49:13.821212Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "try:\n", " # Attempt to compute the gradient of the run_simulations function\n", " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", "except TypeError as e:\n", " print(\"Error:\", e)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -338,21 +293,16 @@ }, { "cell_type": "code", - "execution_count": null, "id": "a6704182200e6438", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:20.085633Z", - "start_time": "2025-01-29T15:49:13.853364Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "import equinox as eqx\n", "\n", "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -364,18 +314,13 @@ }, { "cell_type": "code", - "execution_count": null, "id": "c00c1581d7173d7a", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:20.096400Z", - "start_time": "2025-01-29T15:49:20.093962Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "grad.parameters" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -387,18 +332,13 @@ }, { "cell_type": "code", - "execution_count": null, "id": "f7c17f7459d0151f", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:20.123274Z", - "start_time": "2025-01-29T15:49:20.120144Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "grad" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -410,18 +350,13 @@ }, { "cell_type": "code", - "execution_count": null, "id": "3badd4402cf6b8c6", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:20.151355Z", - "start_time": "2025-01-29T15:49:20.148297Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "grad._my" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -433,15 +368,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "1a91aff44b93157", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:21.966714Z", - "start_time": "2025-01-29T15:49:20.188760Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "import jax.numpy as jnp\n", "import diffrax\n", @@ -490,7 +418,56 @@ "# Compute the gradient with respect to `ts_dyn`\n", "g = grad_ts_dyn(ts_dyn)\n", "g" - ] + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "19ca88c8900584ce", + "metadata": {}, + "source": "## Model training" + }, + { + "cell_type": "markdown", + "id": "7f99c046d7d4e225", + "metadata": {}, + "source": "This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and [optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam." + }, + { + "cell_type": "code", + "id": "24cf05f7866bd250", + "metadata": {}, + "source": [ + "from optax import adam\n", + "\n", + "# define loss function\n", + "loss = eqx.filter_value_and_grad(run_simulations, has_aux=True)\n", + "\n", + "# initialise adam\n", + "optim = adam(0.01)\n", + "# eqx.partition is necessary here to only initialize the optimizer for array variables\n", + "param, static = eqx.partition(jax_problem, eqx.is_array)\n", + "opt_state = optim.init(param)\n", + "\n", + "\n", + "# define update function\n", + "@eqx.filter_jit\n", + "def make_step(problem, opt_state):\n", + " current_loss, grads = loss(problem)\n", + " updates, opt_state = optim.update(grads, opt_state)\n", + " model = eqx.apply_updates(problem, updates)\n", + " return current_loss, model, opt_state\n", + "\n", + "\n", + "# run 5 optimisation steps\n", + "for step in range(5):\n", + " current_loss, jax_problem, opt_state = make_step(jax_problem, opt_state)\n", + " current_loss = current_loss[0].item()\n", + " print(f\"step={step}, loss={current_loss}\")" + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -504,15 +481,8 @@ }, { "cell_type": "code", - "execution_count": null, "id": "58ebdc110ea7457e", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:22.363492Z", - "start_time": "2025-01-29T15:49:22.028899Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "from time import time\n", "\n", @@ -521,19 +491,14 @@ "\n", "# Define a JIT-compiled gradient function with auxiliary outputs\n", "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "e1242075f7e0faf", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:30.839352Z", - "start_time": "2025-01-29T15:49:22.371391Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "# Measure the time taken for the first function call (including compilation)\n", "start = time()\n", @@ -544,19 +509,14 @@ "start = time()\n", "gradfun(jax_problem)\n", "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "27181f367ccb1817", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:32.125487Z", - "start_time": "2025-01-29T15:49:30.847973Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "%%timeit\n", "run_simulations(\n", @@ -569,19 +529,14 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "5b8d3a6162a3ae55", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:37.566080Z", - "start_time": "2025-01-29T15:49:32.193598Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "%%timeit \n", "gradfun(\n", @@ -594,19 +549,14 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "d733a450635a749b", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:52.877239Z", - "start_time": "2025-01-29T15:49:37.633290Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "from amici.petab import simulate_petab\n", "import amici\n", @@ -627,35 +577,25 @@ "problem_parameters = dict(\n", " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "413ed7c60b2cf4be", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:49:52.891165Z", - "start_time": "2025-01-29T15:49:52.889250Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "# Profile simulation only\n", "solver.setSensitivityOrder(amici.SensitivityOrder.none)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "768fa60e439ca8b4", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:50:06.598838Z", - "start_time": "2025-01-29T15:49:52.902527Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "%%timeit \n", "simulate_petab(\n", @@ -666,36 +606,26 @@ " scaled_parameters=True,\n", " scaled_gradients=True,\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "b8382b0b2b68f49e", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:50:06.660478Z", - "start_time": "2025-01-29T15:50:06.658434Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "# Profile gradient computation using forward sensitivity analysis\n", "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", "solver.setSensitivityMethod(amici.SensitivityMethod.forward)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "3bae1fab8c416122", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:50:22.127188Z", - "start_time": "2025-01-29T15:50:06.673328Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "%%timeit \n", "simulate_petab(\n", @@ -706,36 +636,26 @@ " scaled_parameters=True,\n", " scaled_gradients=True,\n", ")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "71e0358227e1dc74", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:50:22.195899Z", - "start_time": "2025-01-29T15:50:22.193851Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "# Profile gradient computation using adjoint sensitivity analysis\n", "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", "solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "code", - "execution_count": null, "id": "e3cc7971002b6d06", - "metadata": { - "ExecuteTime": { - "end_time": "2025-01-29T15:50:24.178434Z", - "start_time": "2025-01-29T15:50:22.207474Z" - } - }, - "outputs": [], + "metadata": {}, "source": [ "%%timeit \n", "simulate_petab(\n", @@ -746,7 +666,9 @@ " scaled_parameters=True,\n", " scaled_gradients=True,\n", ")" - ] + ], + "outputs": [], + "execution_count": null } ], "metadata": { From 2a734e3aeb95ee2b8c56aa2927856676d0c12d18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 28 Jul 2025 15:23:58 +0100 Subject: [PATCH 2/3] Update rtd_requirements.txt --- doc/rtd_requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index b6d906e17a..3dbc195569 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -8,6 +8,7 @@ setuptools>=67.7.2 # for building the documentation, we don't care whether this fully works git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af matplotlib>=3.7.1 +optax nbsphinx nbformat myst-parser From 2eee255ac65681c12bd5588c910fc7111d51c60a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 28 Jul 2025 15:46:53 +0100 Subject: [PATCH 3/3] Update installAmiciSource.sh --- scripts/installAmiciSource.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/installAmiciSource.sh b/scripts/installAmiciSource.sh index 918e6f70f3..807d85dcb7 100755 --- a/scripts/installAmiciSource.sh +++ b/scripts/installAmiciSource.sh @@ -39,6 +39,7 @@ python -m pip install --upgrade pip wheel python -m pip install --upgrade pip setuptools cmake_build_extension==0.6.0 numpy petab swig python -m pip install git+https://github.com/pysb/pysb@master # for SPM with compartments python -m pip install git+https://github.com/patrick-kidger/diffrax@dev # for events with direction +python -m pip install optax # for jax petab notebook AMICI_BUILD_TEMP="${AMICI_PATH}/python/sdist/build/temp" \ python -m pip install --verbose -e "${AMICI_PATH}/python/sdist[petab,test,vis,jax]" --no-build-isolation deactivate