diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml index 98cc735e7a..c01af88739 100644 --- a/.github/workflows/test_petab_sciml.yml +++ b/.github/workflows/test_petab_sciml.yml @@ -3,11 +3,12 @@ on: push: branches: - develop - - master + - main pull_request: branches: - - master + - main - develop + - jax_sciml merge_group: workflow_dispatch: @@ -33,6 +34,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 20 + submodules: recursive - name: Install apt dependencies uses: ./.github/actions/install-apt-dependencies @@ -59,26 +61,26 @@ jobs: - name: Download and install PEtab SciML run: | source ./venv/bin/activate \ - && python -m pip install git+https://github.com/sebapersson/petab_sciml.git@unify_data#subdirectory=src/python \ + && python -m pip install git+https://github.com/sebapersson/petab_sciml.git@main#subdirectory=src/python \ - name: Install petab run: | source ./venv/bin/activate \ && python3 -m pip uninstall -y petab \ - && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@develop \ + && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@sciml \ - name: Run PEtab SciML testsuite run: | source ./venv/bin/activate \ - && pytest --cov-report=xml:coverage.xml \ - --cov=./ tests/sciml/test_sciml.py + && pytest --cov-report=xml:coverage_petab_sciml.xml \ + --cov=amici tests/sciml/test_sciml.py - name: Codecov if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev' uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} - file: coverage.xml - flags: petab + file: coverage_petab_sciml.xml + flags: petab_sciml fail_ci_if_error: true diff --git a/doc/conf.py b/doc/conf.py index 78e8534768..99f1396c0f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -33,6 +33,7 @@ import amici import pandas as pd # noqa: F401 import sympy as sp # noqa: F401 +import warnings def install_doxygen(): @@ -364,6 +365,10 @@ def install_doxygen(): 
"ExpDataPtrVector": ":class:`amici.amici.ExpData`", } +# TODO: alias for forward type definition, remove after release of petab_sciml +autodoc_type_aliases = { + "NNModel": "petab_sciml.NNModel", +} def process_docstring(app, what, name, obj, options, lines): # only apply in the amici.amici module diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 983e139237..2e16bf7d50 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -82,7 +82,9 @@ "cell_type": "markdown", "id": "415962751301c64a", "metadata": {}, - "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." + "source": [ + "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." + ] }, { "cell_type": "code", @@ -94,7 +96,9 @@ "simulation_condition = (\"model1_data1\",)\n", "\n", "# Access the results for the specified condition\n", - "results[simulation_condition]" + "ic = results[\"simulation_conditions\"].index(simulation_condition)\n", + "print(\"llh: \", results[\"llh\"][ic])\n", + "print(\"state variables: \", results[\"x\"][ic, :])" ] }, { @@ -129,7 +133,9 @@ "cell_type": "markdown", "id": "fe4d3b40ee3efdf2", "metadata": {}, - "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." + "source": [ + "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." + ] }, { "cell_type": "code", @@ -180,7 +186,9 @@ "cell_type": "markdown", "id": "4fa97c33719c2277", "metadata": {}, - "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. 
Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." + "source": [ + "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." + ] }, { "cell_type": "code", @@ -283,7 +291,9 @@ "cell_type": "markdown", "id": "dc9bc07cde00a926", "metadata": {}, - "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." + "source": [ + "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." + ] }, { "cell_type": "code", @@ -302,7 +312,9 @@ "cell_type": "markdown", "id": "851c3ec94cb5d086", "metadata": {}, - "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`." + "source": [ + "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly via `grad.parameters`." 
+ ] }, { "cell_type": "code", @@ -318,7 +330,9 @@ "cell_type": "markdown", "id": "375b835fecc5a022", "metadata": {}, - "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." + "source": [ + "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." + ] }, { "cell_type": "code", @@ -334,7 +348,9 @@ "cell_type": "markdown", "id": "8eb7cc3db510c826", "metadata": {}, - "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." + "source": [ + "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." + ] }, { "cell_type": "code", @@ -342,14 +358,16 @@ "metadata": {}, "outputs": [], "source": [ - "grad._measurements[simulation_condition]" + "grad._my[ic, :]" ] }, { "cell_type": "markdown", "id": "58eb04393a1463d", "metadata": {}, - "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation." 
+ "source": [ + "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after solving the differential equation. While this might not be particularly practical, it serves as a nice illustration of the power of automatic differentiation." + ] }, { "cell_type": "code", @@ -671,8 +689,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.0" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index cbb21058c2..da9ccf87ba 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -7,6 +7,8 @@ setuptools>=67.7.2 # https://github.com/pysb/pysb/pull/599 # for building the documentation, we don't care whether this fully works git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af +# For forward type definition in generate_equinox +git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python matplotlib>=3.7.1 optax nbsphinx @@ -16,6 +18,7 @@ sphinx_rtd_theme>=1.2.0 petab[vis]>=0.2.0 sphinx-autodoc-typehints ipython>=8.13.2 +h5py>=3.14.0 breathe>=4.35.0 exhale>=0.3.7 -e git+https://github.com/mithro/sphinx-contrib-mithro#egg=sphinx-contrib-exhale-multiproject&subdirectory=sphinx-contrib-exhale-multiproject diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 1625fed73d..db343019a4 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2564,6 +2564,7 @@ def _process_hybridization(self, hybridization: dict) -> None: hybridization information """ added_expressions = False + orig_obs = tuple([s.get_id() for s in self._observables]) for net_id, net in 
hybridization.items(): if net["static"]: continue # do not integrate into ODEs, handle in amici.jax.petab @@ -2595,7 +2596,7 @@ def _process_hybridization(self, hybridization: dict) -> None: ) outputs = { - out_var: comp + out_var: {"comp": comp, "ind": net["output_vars"][out_var]} for comp in self._components if (out_var := str(comp.get_id())) in net["output_vars"] # TODO: SYNTAX NEEDS to CHANGE @@ -2606,7 +2607,9 @@ def _process_hybridization(self, hybridization: dict) -> None: raise ValueError( f"Could not find all output variables for neural network {net_id}" ) - for iout, (out_var, comp) in enumerate(outputs.items()): + + for out_var, parts in outputs.items(): + comp = parts["comp"] # remove output from model components if isinstance(comp, Parameter): self._parameters.remove(comp) @@ -2621,7 +2624,7 @@ def _process_hybridization(self, hybridization: dict) -> None: # generate dummy Function out_val = sp.Function(net_id)( - *[input.get_id() for input in inputs], iout + *[input.get_id() for input in inputs], parts["ind"] ) # add to the model @@ -2642,6 +2645,42 @@ def _process_hybridization(self, hybridization: dict) -> None: ) added_expressions = True + observables = { + ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]} + for comp in self._components + if (ob_var := str(comp.get_id())) in net["observable_vars"] + # # TODO: SYNTAX NEEDS to CHANGE + # or (ob_var := str(comp.get_id()) + "_dot") + # in net["observable_vars"] + } + if len(observables.keys()) != len(net["observable_vars"]): + raise ValueError( + f"Could not find all observable variables for neural network {net_id}" + ) + + for ob_var, parts in observables.items(): + comp = parts["comp"] + if isinstance(comp, Observable): + self._observables.remove(comp) + else: + raise ValueError( + f"{comp.get_name()} ({type(comp)}) is not an observable." 
+ ) + out_val = sp.Function(net_id)( + *[input.get_id() for input in inputs], parts["ind"] + ) + # add to the model + self.add_component( + Observable( + identifier=comp.get_id(), + name=net_id, + value=out_val, + ) + ) + + new_order = [orig_obs.index(s.get_id()) for s in self._observables] + self._observables = [self._observables[i] for i in new_order] + if added_expressions: # toposort expressions w_sorted = toposort_symbols( diff --git a/python/sdist/amici/jax/jaxcodeprinter.py b/python/sdist/amici/jax/jaxcodeprinter.py index f96186ac3b..ed022e9fd6 100644 --- a/python/sdist/amici/jax/jaxcodeprinter.py +++ b/python/sdist/amici/jax/jaxcodeprinter.py @@ -6,6 +6,7 @@ import sympy as sp from sympy.printing.numpy import NumPyPrinter +from sympy.core.function import UndefinedFunction def _jnp_array_str(array) -> str: @@ -42,8 +43,11 @@ def _print_Mul(self, expr: sp.Expr) -> str: return super()._print_Mul(expr) return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})" - def _print_Function(self, expr): - return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]" + def _print_Function(self, expr: sp.Expr) -> str: + if isinstance(expr.func, UndefinedFunction): + return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]" + else: + return super()._print_Function(expr) def _print_Max(self, expr: sp.Expr) -> str: """ diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 917836b6d9..d920bf5ad0 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -545,6 +545,8 @@ def simulate_condition( x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override_mask: jt.Bool[jt.Array, "*nx"] = 
jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]: @@ -588,6 +590,10 @@ def simulate_condition( mask for re-initialization. If `True`, the corresponding state variable is re-initialized. :param x_reinit: re-initialized state vector. If not provided, the state vector is not re-initialized. + :param init_override: + override model input e.g. with neural net outputs. If not provided, the inputs are not overridden. + :param init_override_mask: + mask for input override. If `True`, the corresponding input is replaced with value from init_override. :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. @@ -602,6 +608,11 @@ def simulate_condition( if x_preeq.shape[0]: x = x_preeq + elif init_override.shape[0]: + x_def = self._x0(t0, p) + x = jnp.squeeze( + jnp.where(init_override_mask, init_override, x_def) + ) else: x = self._x0(t0, p) diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py index a7d49bd0f7..f94a0dcca4 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/jax/nn.py @@ -30,7 +30,7 @@ def tanhshrink(x: jnp.ndarray) -> jnp.ndarray: return x - jnp.tanh(x) -def generate_equinox(nn_model: "NNModel", filename: Path | str): # noqa: F821 +def generate_equinox(nn_model: "NNModel", filename: Path | str, frozen_layers: dict = {}): # noqa: F821 # TODO: move to top level import and replace forward type definitions from petab_sciml import Layer @@ -53,6 +53,7 @@ def generate_equinox(nn_model: "NNModel", filename: Path | str): # noqa: F821 _generate_forward( node, node_indent, + frozen_layers, layers.get( node.target, Layer(layer_id="dummy", layer_type="Linear"), @@ -120,7 +121,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # 
noqa: F "bias": "use_bias", }, "LayerNorm": { - "affine": "elementwise_affine", + "elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias) "normalized_shape": "shape", }, } @@ -149,13 +150,25 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F return f"{' ' * indent}'{layer.layer_id}': {layer_str}" -def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F821 +def _generate_forward(node: "Node", indent, frozen_layers: dict = {}, layer_type=str) -> str: # noqa: F821 if node.op == "placeholder": # TODO: inconsistent target vs name return f"{' ' * indent}{node.name} = input" if node.op == "call_module": fun_str = f"self.layers['{node.target}']" + if node.name in frozen_layers: + if frozen_layers[node.name]: + arr_attr = frozen_layers[node.name] + get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')" + replacer = ( + "replace_fn = lambda arr: jax.lax.stop_gradient(arr)" + ) + tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})" + fun_str = f"tree_{node.name}" + else: + fun_str = f"jax.lax.stop_gradient({fun_str})" + tree_string = "" if layer_type.startswith(("Conv", "Linear", "LayerNorm")): if layer_type in ("LayerNorm",): dims = f"len({fun_str}.shape)+1" @@ -190,12 +203,17 @@ def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F82 args = ", ".join([f"{arg}" for arg in node.args]) kwargs = [ - "=".join(item) for item in node.kwargs.items() if k not in ("inplace",) + f"{k}={item}" + for k, item in node.kwargs.items() + if k not in ("inplace",) ] if layer_type.startswith(("Dropout",)): kwargs += ["key=key"] kwargs_str = ", ".join(kwargs) if node.op in ("call_module", "call_function", "call_method"): - return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})" + if node.name in frozen_layers: + return f"{' ' * indent}{tree_string}\n{' ' * indent}{node.name} = {fun_str}({args + ', ' 
+ kwargs_str})" + else: + return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})" if node.op == "output": return f"{' ' * indent}{node.target} = {args}" diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index c57492e390..69cc413c73 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -6,7 +6,7 @@ The user generally won't have to directly call any function from this module as this will be done by :py:func:`amici.pysb_import.pysb2jax`, -:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and +:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and :py:func:`amici.petab_import.import_model`. """ @@ -283,12 +283,12 @@ def _generate_jax_code(self) -> None: tpl_data, ) - def _generate_nn_code(self) -> None: for net_name, net in self.hybridisation.items(): generate_equinox( net["model"], self.model_path / f"{net_name}.py", + net["frozen_layers"], ) def _implicit_roots(self) -> list[sp.Expr]: @@ -303,7 +303,6 @@ def _implicit_roots(self) -> list[sp.Expr]: roots.append(root) return roots - def set_paths(self, output_dir: str | Path | None = None) -> None: """ Set output paths for the model and create if necessary diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 29f1e65b66..520c141d9c 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -7,6 +7,7 @@ from pathlib import Path from collections.abc import Callable import logging +from typing import Union import diffrax @@ -16,10 +17,12 @@ import jaxtyping as jt import jax.lax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np import pandas as pd import petab.v1 as petab import h5py +import re from amici import _module_from_path from amici.petab.parameter_mapping import ( @@ -75,6 +78,31 @@ def jax_unscale( raise ValueError(f"Invalid parameter scaling: {scale_str}") +# IDEA: Implement this class in petab-sciml instead? 
+class HybridProblem(petab.Problem): + hybridization_df: pd.DataFrame + + def __init__(self, petab_problem: petab.Problem): + self.__dict__.update(petab_problem.__dict__) + self.hybridization_df = _get_hybridization_df(petab_problem) + + +def _get_hybridization_df(petab_problem): + if "sciml" in petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t", index_col=0) + for hf in petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + return hybridization_df + + +def _get_hybrid_petab_problem(petab_problem: petab.Problem): + return HybridProblem(petab_problem) + + class JAXProblem(eqx.Module): """ PEtab problem wrapper for JAX models. @@ -108,7 +136,7 @@ class JAXProblem(eqx.Module): _np_mask: np.ndarray _np_indices: np.ndarray _petab_measurement_indices: np.ndarray - _petab_problem: petab.Problem + _petab_problem: petab.Problem | HybridProblem def __init__(self, model: JAXModel, petab_problem: petab.Problem): """ @@ -121,7 +149,7 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): """ scs = petab_problem.get_simulation_conditions_from_measurement_df() self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) - self._petab_problem = petab_problem + self._petab_problem = _get_hybrid_petab_problem(petab_problem) self.parameters, self.model = self._get_nominal_parameter_values(model) self._parameter_mappings = self._get_parameter_mappings(scs) ( @@ -524,24 +552,55 @@ def _get_nominal_parameter_values( for net_id, nn in model.nns.items() } # load nn parameters from file - par_arrays = { - array_id: h5py.File(file_spec["location"], "r") - for array_id, file_spec in self._petab_problem.extensions_config[ - "array_files" - ].items() - # TODO: FIXME (https://github.com/sebapersson/petab_sciml_testsuite/issues/1) - } + par_arrays = ( + dict( + [ + ( + file_spec.split("_")[0], + h5py.File(file_spec, "r")["parameters"][ + file_spec.split("_")[0] + ], + ) + for 
file_spec in self._petab_problem.extensions_config[ + "sciml" + ]["array_files"] + if "parameters" in h5py.File(file_spec, "r").keys() + ] + ) + if self._petab_problem.extensions_config + else {} + ) + + nn_input_arrays = ( + dict( + [ + ( + file_spec.split("_")[0], + h5py.File(file_spec, "r")["inputs"], + ) + for file_spec in self._petab_problem.extensions_config[ + "sciml" + ]["array_files"] + if "inputs" in h5py.File(file_spec, "r").keys() + ] + ) + if self._petab_problem.extensions_config + else {} + ) # extract nominal values from petab problem for pname, row in self._petab_problem.parameter_df.iterrows(): - if (net := pname.split(".")[0]) in model.nns: + if (net := pname.split("_")[0]) in model.nns: to_set = [] nn = model_pars[net] - try: - value = float(row[petab.NOMINAL_VALUE]) - except ValueError: - value = par_arrays[row[petab.NOMINAL_VALUE]] + scalar = True + + if np.isnan(row[petab.NOMINAL_VALUE]): + value = par_arrays[net] scalar = False + else: + value = float(row[petab.NOMINAL_VALUE]) + if len(pname.split(".")) > 1: layer_name = pname.split(".")[1] layer = nn[layer_name] @@ -567,11 +626,11 @@ def _get_nominal_parameter_values( for layer, attribute in to_set: if scalar: nn[layer][attribute] = value * jnp.ones_like( - model.nns[net].layers[layer][attribute] + getattr(model.nns[net].layers[layer], attribute) ) else: nn[layer][attribute] = jnp.array( - value[layer][attribute] + value[layer][attribute][:] ) # set values in model @@ -588,6 +647,26 @@ def _get_nominal_parameter_values( model, model_pars[net_id][layer_id][attribute], ) + + # set inputs in the model if provided + if len(nn_input_arrays) > 0: + for net_id in model_pars: + input_array = { + input: { + k: jnp.array( + arr[:], + dtype=jnp.float64 + if jax.config.jax_enable_x64 + else jnp.float32, + ) + for k, arr in nn_input_arrays[net_id][input].items() + } + for input in model.nns[net_id].inputs + } + model = eqx.tree_at( + lambda model: model.nns[net_id].inputs, model, input_array + ) + return 
jnp.array( [ petab.scale( @@ -691,35 +770,107 @@ def _unscale( [jax_unscale(pval, scale) for pval, scale in zip(p, scales)] ) - def _eval_nn(self, output_par: str): + def _eval_nn(self, output_par: str, condition_id: str): net_id = self._petab_problem.mapping_df.loc[ output_par, petab.MODEL_ENTITY_ID ].split(".")[0] nn = self.model.nns[net_id] + def _is_net_input(model_id): + comps = model_id.split(".") + return comps[0] == net_id and comps[1].startswith("inputs") + model_id_map = ( self._petab_problem.mapping_df[ - self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] - .str.split(".") - .str[0] - == net_id + self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID].apply( + _is_net_input + ) ] .reset_index() .set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID] .to_dict() ) + condition_input_map = ( + dict( + [ + ( + petab_id, + self._petab_problem.parameter_df.loc[ + self._petab_problem.condition_df.loc[ + condition_id, petab_id + ], + petab.NOMINAL_VALUE, + ], + ) + if self._petab_problem.condition_df.loc[ + condition_id, petab_id + ] + in self._petab_problem.parameter_df.index + else ( + petab_id, + np.float64( + self._petab_problem.condition_df.loc[ + condition_id, petab_id + ] + ), + ) + for petab_id in model_id_map.values() + ] + ) + if not self._petab_problem.condition_df.empty + else {} + ) + + hybridization_parameter_map = dict( + [ + ( + petab_id, + self._petab_problem.hybridization_df.loc[ + petab_id, "targetValue" + ], + ) + for petab_id in model_id_map.values() + if petab_id in set(self._petab_problem.hybridization_df.index) + ] + ) + + # handle conditions + if len(condition_input_map) > 0: + net_input = jnp.array( + [ + condition_input_map[petab_id] + for _, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + + # handle array inputs + if isinstance(self.model.nns[net_id].inputs, dict): + net_input = jnp.array( + [ + self.model.nns[net_id].inputs[petab_id][condition_id] + if condition_id in 
self.model.nns[net_id].inputs[petab_id] + else self.model.nns[net_id].inputs[petab_id]["0"] + for _, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + net_input = jnp.array( [ - jax.lax.stop_gradient(self._inputs[net_id][model_id]) - if model_id in self._inputs[net_id] + jax.lax.stop_gradient(self.model.nns[net_id][model_id]) + if model_id in self.model.nns[net_id].inputs else self.get_petab_parameter_by_id(petab_id) if petab_id in self.parameter_ids else self._petab_problem.parameter_df.loc[ petab_id, petab.NOMINAL_VALUE ] + if petab_id in set(self._petab_problem.parameter_df.index) + else self._petab_problem.parameter_df.loc[ + hybridization_parameter_map[petab_id], petab.NOMINAL_VALUE + ] for model_id, petab_id in model_id_map.items() - if model_id.split(".")[1].startswith("input") ] ) return nn.forward(net_input).squeeze() @@ -728,10 +879,19 @@ def _map_model_parameter_value( self, mapping: ParameterMappingForCondition, pname: str, + condition_id: str, ) -> jt.Float[jt.Scalar, ""] | float: # noqa: F722 - if pname in self.nn_output_ids: - return self._eval_nn(pname) pval = mapping.map_sim_var[pname] + if pval in self.nn_output_ids: + nn_output = self._eval_nn(pval, condition_id) + if nn_output.size > 1: + entityId = self._petab_problem.mapping_df.loc[ + pval, petab.MODEL_ENTITY_ID + ] + ind = int(re.search(r"\[\d+\]\[(\d+)\]", entityId).group(1)) + return nn_output[ind] + else: + return nn_output if isinstance(pval, Number): return pval return self.get_petab_parameter_by_id(pval) @@ -751,7 +911,9 @@ def load_model_parameters( p = jnp.array( [ - self._map_model_parameter_value(mapping, pname) + self._map_model_parameter_value( + mapping, pname, simulation_condition + ) for pname in self.model.parameter_ids ] ) @@ -928,17 +1090,21 @@ def _prepare_conditions( p_array = jnp.stack( [self.load_model_parameters(sc) for sc in conditions] ) - unscaled_parameters = jnp.stack( - [ - jax_unscale( - self.parameters[ip], - 
self._petab_problem.parameter_df.loc[ - p_id, petab.PARAMETER_SCALE - ], - ) - for ip, p_id in enumerate(self.parameter_ids) - ] - ) + + if self.parameters.size: + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], + self._petab_problem.parameter_df.loc[ + p_id, petab.PARAMETER_SCALE + ], + ) + for ip, p_id in enumerate(self.parameter_ids) + ] + ) + else: + unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) if op_numeric is not None and op_numeric.size: op_array = jnp.where( @@ -994,6 +1160,8 @@ def run_simulation( nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + init_override: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + init_override_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1059,6 +1227,8 @@ def run_simulation( x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, + init_override=init_override, + init_override_mask=jax.lax.stop_gradient(init_override_mask), ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), solver=solver, controller=controller, @@ -1118,6 +1288,38 @@ def run_simulations( self._np_indices, ) ) + + init_override_mask = jnp.stack( + [ + jnp.array( + [ + True + if p + in set(self._parameter_mappings[sc].map_sim_var.keys()) + else False + for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) + init_override = jnp.stack( + [ + jnp.array( + [ + self._eval_nn( + self._parameter_mappings[sc].map_sim_var[p], sc + ) + if p + in set(self._parameter_mappings[sc].map_sim_var.keys()) + else 1.0 + for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) + return self.run_simulation( p_array, self._ts_dyn, @@ -1129,6 +1331,8 @@ def run_simulations( np_array, mask_reinit_array, x_reinit_array, + 
init_override, + init_override_mask, solver, controller, root_finder, diff --git a/python/sdist/amici/petab/parameter_mapping.py b/python/sdist/amici/petab/parameter_mapping.py index beac7321bf..0f70ce0d71 100644 --- a/python/sdist/amici/petab/parameter_mapping.py +++ b/python/sdist/amici/petab/parameter_mapping.py @@ -378,13 +378,20 @@ def create_parameter_mapping( if parameter_mapping_kwargs is None: parameter_mapping_kwargs = {} + # TODO: Add support for conditions with sciml mappings in petab library + mapping = ( + None + if "sciml" in petab_problem.extensions_config + else petab_problem.mapping_df + ) + prelim_parameter_mapping = ( petab.get_optimization_to_simulation_parameter_mapping( condition_df=petab_problem.condition_df, measurement_df=petab_problem.measurement_df, parameter_df=petab_problem.parameter_df, observable_df=petab_problem.observable_df, - mapping_df=petab_problem.mapping_df, + mapping_df=mapping, model=petab_problem.model, simulation_conditions=simulation_conditions, fill_fixed_parameters=fill_fixed_parameters, @@ -585,6 +592,24 @@ def create_parameter_mapping_for_condition( ) logger.debug(f"Merged: {condition_map_sim_var}") + if "sciml" in petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t") + for hf in petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + for net_id, config in petab_problem.extensions_config["sciml"][ + "neural_nets" + ].items(): + if config["static"]: + for _, row in hybridization_df.iterrows(): + if row["targetValue"].startswith(net_id): + condition_map_sim_var[row["targetId"]] = row[ + "targetValue" + ] + parameter_mapping_for_condition = ParameterMappingForCondition( map_preeq_fix=condition_map_preeq_fix, map_sim_fix=condition_map_sim_fix, diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 8119d970ac..0317c14256 100644 --- a/python/sdist/amici/petab/petab_import.py +++ 
b/python/sdist/amici/petab/petab_import.py @@ -9,6 +9,7 @@ import os import shutil from pathlib import Path +import re import amici import pandas as pd @@ -129,14 +130,17 @@ def import_petab_problem( logger.info(f"Compiling model {model_name} to {model_output_dir}.") - if "neural_nets" in petab_problem.extensions_config: # TODO: fixme + if "sciml" in petab_problem.extensions_config: from petab_sciml.standard import NNModelStandard - config = petab_problem.extensions_config + config = petab_problem.extensions_config["sciml"] # TODO: only accept YAML format for now - hybridization_table = pd.read_csv( - config["hybridization_file"], sep="\t" - ) + hybridizations = [ + pd.read_csv(hf, sep="\t") + for hf in config["hybridization_files"] + ] + hybridization_table = pd.concat(hybridizations) + input_mapping = dict( zip( hybridization_table["targetId"], @@ -149,6 +153,12 @@ def import_petab_problem( hybridization_table["targetId"], ) ) + observable_mapping = dict( + zip( + petab_problem.observable_df["observableFormula"], + petab_problem.observable_df.index, + ) + ) hybridization = { net_id: { "model": NNModelStandard.load_data( @@ -166,20 +176,65 @@ def import_petab_problem( .to_dict() .items() if model_id.split(".")[1].startswith("input") + and petab_id in input_mapping.keys() ], - "output_vars": [ - output_mapping[petab_id] - for petab_id, model_id in petab_problem.mapping_df.loc[ - petab_problem.mapping_df[petab.MODEL_ENTITY_ID] - .str.split(".") - .str[0] - == net_id, - petab.MODEL_ENTITY_ID, + "output_vars": dict( + [ + ( + output_mapping[petab_id], + _get_net_index(model_id), + ) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("output") + and petab_id in output_mapping.keys() ] - .to_dict() - .items() - if model_id.split(".")[1].startswith("output") - ], + ), + 
"observable_vars": dict( + [ + ( + observable_mapping[petab_id], + _get_net_index(model_id), + ) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("output") + and petab_id in observable_mapping.keys() + ] + ), + "frozen_layers": dict( + [ + _get_frozen_layers(model_id) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if petab_id in petab_problem.parameter_df.index + and petab_problem.parameter_df.loc[ + petab_id, petab.ESTIMATE + ] + == 0 + ] + ), **net_config, } for net_id, net_config in config["neural_nets"].items() @@ -233,3 +288,21 @@ def import_petab_problem( ) return model + + +def _get_net_index(model_id: str): + matches = re.findall(r"\[(\d+)\]", model_id) + if matches: + return int(matches[-1]) + + +def _get_frozen_layers(model_id): + layers = re.findall(r"\[(.*?)\]", model_id) + array_attr = model_id.split(".")[-1] + layer_id = layers[0] if len(layers) else None + array_attr = array_attr if array_attr in ("weight", "bias") else None + return layer_id, array_attr + + +# for backwards compatibility +import_model = import_model_sbml diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 9b46f80fc2..3f00a5747d 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -50,9 +50,7 @@ def change_directory(destination): def _reshape_flat_array(array_flat): array_flat["ix"] = array_flat["ix"].astype(str) - ix_cols = [ - f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";"))) - ] + ix_cols = [f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";")))] if len(ix_cols) == 1: array_flat[ix_cols[0]] = array_flat["ix"].apply(int) else: @@ -66,9 +64,7 @@ def 
_reshape_flat_array(array_flat): return array -@pytest.mark.parametrize( - "test", sorted([d.stem for d in net_cases_dir.glob("[0-9]*")]) -) +@pytest.mark.parametrize("test", sorted([d.stem for d in net_cases_dir.glob("[0-9]*")])) def test_net(test): test_dir = net_cases_dir / test with open(test_dir / "solutions.yaml") as f: @@ -78,107 +74,101 @@ def test_net(test): net_file = cases_dir / test.replace("_alt", "") / solutions["net_file"] else: net_file = test_dir / solutions["net_file"] - ml_models = NNModelStandard.load_data(net_file) + ml_model = NNModelStandard.load_data(net_file) nets = {} outdir = Path(__file__).parent / "models" / test - for ml_model in ml_models.models: - module_dir = outdir / f"{ml_model.mlmodel_id}.py" - if test in ( - "002", - "009", - "018", - "019", - "020", - "021", - "022", - "042", - "043", - "044", - "045", - "046", - "047", - "048", - ): - with pytest.raises(NotImplementedError): - generate_equinox(ml_model, module_dir) - return - generate_equinox(ml_model, module_dir) - nets[ml_model.mlmodel_id] = amici._module_from_path( - ml_model.mlmodel_id, module_dir - ).net + module_dir = outdir / f"{ml_model.nn_model_id}.py" + if test in ( + "002", + "009", + "018", + "019", + "020", + "021", + "022", + "042", + "043", + "044", + "045", + "046", + "047", + "048", + ): + with pytest.raises(NotImplementedError): + generate_equinox(ml_model, module_dir) + return + generate_equinox(ml_model, module_dir) + nets[ml_model.nn_model_id] = amici._module_from_path( + ml_model.nn_model_id, module_dir + ).net for input_file, par_file, output_file in zip( solutions["net_input"], solutions.get("net_ps", solutions["net_input"]), solutions["net_output"], ): - input = h5py.File(test_dir / input_file, "r")["input"][:] - output = h5py.File(test_dir / output_file, "r")["output"][:] + input = h5py.File(test_dir / input_file, "r")["inputs"]["input0"]["data"][:] + output = h5py.File(test_dir / output_file, "r")["outputs"]["output0"]["data"][:] if "net_ps" in 
solutions: par = h5py.File(test_dir / par_file, "r") - for ml_model in ml_models.models: - net = nets[ml_model.mlmodel_id](jr.PRNGKey(0)) - for layer in net.layers.keys(): - if ( - isinstance(net.layers[layer], eqx.Module) - and hasattr(net.layers[layer], "weight") - and net.layers[layer].weight is not None - ): - w = par[layer]["weight"][:] - if isinstance(net.layers[layer], eqx.nn.ConvTranspose): - # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose - w = np.flip( - w, axis=tuple(range(2, w.ndim)) - ).swapaxes(0, 1) - assert w.shape == net.layers[layer].weight.shape - net = eqx.tree_at( - lambda x: x.layers[layer].weight, - net, - jnp.array(w), - ) - if ( - isinstance(net.layers[layer], eqx.Module) - and hasattr(net.layers[layer], "bias") - and net.layers[layer].bias is not None + net = nets[ml_model.nn_model_id](jr.PRNGKey(0)) + for layer in net.layers.keys(): + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "weight") + and net.layers[layer].weight is not None + ): + w = par["parameters"][ml_model.nn_model_id][layer]["weight"][:] + if isinstance(net.layers[layer], eqx.nn.ConvTranspose): + # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose + w = np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes(0, 1) + assert w.shape == net.layers[layer].weight.shape + net = eqx.tree_at( + lambda x: x.layers[layer].weight, + net, + jnp.array(w), + ) + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "bias") + and net.layers[layer].bias is not None + ): + b = par["parameters"][ml_model.nn_model_id][layer]["bias"][:] + if isinstance( + net.layers[layer], + eqx.nn.Conv | eqx.nn.ConvTranspose, ): - b = par[layer]["bias"][:] - if isinstance( - net.layers[layer], - eqx.nn.Conv | eqx.nn.ConvTranspose, - ): - b = np.expand_dims( - b, - tuple( - range( - 1, - net.layers[layer].num_spatial_dims + 1, - ) - ), - ) - assert b.shape == 
net.layers[layer].bias.shape - net = eqx.tree_at( - lambda x: x.layers[layer].bias, - net, - jnp.array(b), + b = np.expand_dims( + b, + tuple( + range( + 1, + net.layers[layer].num_spatial_dims + 1, + ) + ), ) - net = eqx.nn.inference_mode(net) + assert b.shape == net.layers[layer].bias.shape + net = eqx.tree_at( + lambda x: x.layers[layer].bias, + net, + jnp.array(b), + ) + net = eqx.nn.inference_mode(net) - if test == "net_004_alt": - return # skipping, no support for non-cross-correlation in equinox + if test == "net_004_alt": + return # skipping, no support for non-cross-correlation in equinox - np.testing.assert_allclose( - net.forward(input), - output, - atol=1e-3, - rtol=1e-3, - ) + np.testing.assert_allclose( + net.forward(input), + output, + atol=1e-3, + rtol=1e-3, + ) -@pytest.mark.parametrize( - "test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")]) -) +@pytest.mark.parametrize("test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")])) def test_ude(test): test_dir = ude_cases_dir / test with open(test_dir / "petab" / "problem.yaml") as f: @@ -201,13 +191,6 @@ def test_ude(test): jax_problem = JAXProblem(jax_model, petab_problem) # llh - if test in ( - "004", - "016", - ): - with pytest.raises(NotImplementedError): - run_simulations(jax_problem) - return llh, r = run_simulations(jax_problem) np.testing.assert_allclose( llh, @@ -246,6 +229,7 @@ def test_ude(test): expected = pd.read_csv(test_dir / file, sep="\t").set_index( petab.PARAMETER_ID ) + for ip in expected.index: if ip in jax_problem.parameter_ids: actual_dict[ip] = sllh.parameters[ @@ -260,13 +244,9 @@ def test_ude(test): ) else: expected = h5py.File(test_dir / file, "r") - for layer_name, layer in jax_problem.model.nns[ - component - ].layers.items(): + for layer_name, layer in jax_problem.model.nns[component].layers.items(): for attribute in dir(layer): - if not isinstance( - getattr(layer, attribute), jax.numpy.ndarray - ): + if not isinstance(getattr(layer, attribute), 
jax.numpy.ndarray): continue actual = getattr( sllh.model.nns[component].layers[layer_name], attribute @@ -279,9 +259,21 @@ def test_ude(test): actual.swapaxes(0, 1), axis=tuple(range(2, actual.ndim)), ) - np.testing.assert_allclose( - actual, - expected[layer_name][attribute][:], - atol=solutions["tol_grad_llh"], - rtol=solutions["tol_grad_llh"], - ) + if ( + np.squeeze( + expected["parameters"][component][layer_name][attribute][:] + ).size + == 0 + ): + assert np.all(actual == 0.0) + else: + np.testing.assert_allclose( + np.squeeze(actual), + np.squeeze( + expected["parameters"][component][layer_name][ + attribute + ][:] + ), + atol=solutions["tol_grad_llh"], + rtol=solutions["tol_grad_llh"], + ) diff --git a/tests/sciml/testsuite b/tests/sciml/testsuite index 3d5f91543d..596cf82a14 160000 --- a/tests/sciml/testsuite +++ b/tests/sciml/testsuite @@ -1 +1 @@ -Subproject commit 3d5f91543d000f2468c7380853db4c0206596a00 +Subproject commit 596cf82a145093bb893420d79ea93be5ebfc725b