From ec1ed55a31bbb41a8ad3f6a49b9448684a617f72 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 1 Sep 2025 16:37:18 +0100 Subject: [PATCH 01/20] Getting test_net petab-sciml tests to pass - update sciml testsuite submodule to point at main - fix a eqx LayerNorm deprecation warning - fix string formatting of bool in kwarg - updates to test code driven by updated sciml format --- python/sdist/amici/jax/nn.py | 4 +- tests/sciml/test_sciml.py | 160 +++++++++++++++++------------------ tests/sciml/testsuite | 2 +- 3 files changed, 83 insertions(+), 83 deletions(-) diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py index a7d49bd0f7..e26d58693e 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/jax/nn.py @@ -120,7 +120,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F "bias": "use_bias", }, "LayerNorm": { - "affine": "elementwise_affine", + "elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias) "normalized_shape": "shape", }, } @@ -190,7 +190,7 @@ def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F82 args = ", ".join([f"{arg}" for arg in node.args]) kwargs = [ - "=".join(item) for item in node.kwargs.items() if k not in ("inplace",) + f"{k}={item}" for k, item in node.kwargs.items() if k not in ("inplace",) ] if layer_type.startswith(("Dropout",)): kwargs += ["key=key"] diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 9b46f80fc2..72629d023d 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -78,102 +78,102 @@ def test_net(test): net_file = cases_dir / test.replace("_alt", "") / solutions["net_file"] else: net_file = test_dir / solutions["net_file"] - ml_models = NNModelStandard.load_data(net_file) + ml_model = NNModelStandard.load_data(net_file) nets = {} outdir = Path(__file__).parent / "models" / test - for ml_model in ml_models.models: - module_dir = 
outdir / f"{ml_model.mlmodel_id}.py" - if test in ( - "002", - "009", - "018", - "019", - "020", - "021", - "022", - "042", - "043", - "044", - "045", - "046", - "047", - "048", - ): - with pytest.raises(NotImplementedError): - generate_equinox(ml_model, module_dir) - return - generate_equinox(ml_model, module_dir) - nets[ml_model.mlmodel_id] = amici._module_from_path( - ml_model.mlmodel_id, module_dir - ).net + module_dir = outdir / f"{ml_model.nn_model_id}.py" + if test in ( + "002", + "009", + "018", + "019", + "020", + "021", + "022", + "042", + "043", + "044", + "045", + "046", + "047", + "048", + ): + with pytest.raises(NotImplementedError): + generate_equinox(ml_model, module_dir) + return + generate_equinox(ml_model, module_dir) + nets[ml_model.nn_model_id] = amici._module_from_path( + ml_model.nn_model_id, module_dir + ).net for input_file, par_file, output_file in zip( solutions["net_input"], solutions.get("net_ps", solutions["net_input"]), solutions["net_output"], ): - input = h5py.File(test_dir / input_file, "r")["input"][:] - output = h5py.File(test_dir / output_file, "r")["output"][:] + input = h5py.File(test_dir / input_file, "r")["inputs"]["input0"]["data"][:] + output = h5py.File(test_dir / output_file, "r")["outputs"]["output0"]["data"][:] if "net_ps" in solutions: par = h5py.File(test_dir / par_file, "r") - for ml_model in ml_models.models: - net = nets[ml_model.mlmodel_id](jr.PRNGKey(0)) - for layer in net.layers.keys(): - if ( - isinstance(net.layers[layer], eqx.Module) - and hasattr(net.layers[layer], "weight") - and net.layers[layer].weight is not None - ): - w = par[layer]["weight"][:] - if isinstance(net.layers[layer], eqx.nn.ConvTranspose): - # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose - w = np.flip( - w, axis=tuple(range(2, w.ndim)) - ).swapaxes(0, 1) - assert w.shape == net.layers[layer].weight.shape - net = eqx.tree_at( - lambda x: x.layers[layer].weight, - net, - jnp.array(w), - ) - if ( - 
isinstance(net.layers[layer], eqx.Module) - and hasattr(net.layers[layer], "bias") - and net.layers[layer].bias is not None + net = nets[ml_model.nn_model_id](jr.PRNGKey(0)) + for layer in net.layers.keys(): + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "weight") + and net.layers[layer].weight is not None + ): + # ?? grabbing weights from the parameters file ?? need to check if they are present in above if condition ?? + w = par["parameters"][ml_model.nn_model_id][layer]["weight"][:] + if isinstance(net.layers[layer], eqx.nn.ConvTranspose): + # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose + w = np.flip( + w, axis=tuple(range(2, w.ndim)) + ).swapaxes(0, 1) + assert w.shape == net.layers[layer].weight.shape + net = eqx.tree_at( + lambda x: x.layers[layer].weight, + net, + jnp.array(w), + ) + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "bias") + and net.layers[layer].bias is not None + ): + # ?? grabbing biases from the parameters file ?? need to check if they are present in above if condition ?? 
+ b = par["parameters"][ml_model.nn_model_id][layer]["bias"][:] + if isinstance( + net.layers[layer], + eqx.nn.Conv | eqx.nn.ConvTranspose, ): - b = par[layer]["bias"][:] - if isinstance( - net.layers[layer], - eqx.nn.Conv | eqx.nn.ConvTranspose, - ): - b = np.expand_dims( - b, - tuple( - range( - 1, - net.layers[layer].num_spatial_dims + 1, - ) - ), - ) - assert b.shape == net.layers[layer].bias.shape - net = eqx.tree_at( - lambda x: x.layers[layer].bias, - net, - jnp.array(b), + b = np.expand_dims( + b, + tuple( + range( + 1, + net.layers[layer].num_spatial_dims + 1, + ) + ), ) - net = eqx.nn.inference_mode(net) + assert b.shape == net.layers[layer].bias.shape + net = eqx.tree_at( + lambda x: x.layers[layer].bias, + net, + jnp.array(b), + ) + net = eqx.nn.inference_mode(net) - if test == "net_004_alt": - return # skipping, no support for non-cross-correlation in equinox + if test == "net_004_alt": + return # skipping, no support for non-cross-correlation in equinox - np.testing.assert_allclose( - net.forward(input), - output, - atol=1e-3, - rtol=1e-3, - ) + np.testing.assert_allclose( + net.forward(input), + output, + atol=1e-3, + rtol=1e-3, + ) @pytest.mark.parametrize( diff --git a/tests/sciml/testsuite b/tests/sciml/testsuite index 3d5f91543d..596cf82a14 160000 --- a/tests/sciml/testsuite +++ b/tests/sciml/testsuite @@ -1 +1 @@ -Subproject commit 3d5f91543d000f2468c7380853db4c0206596a00 +Subproject commit 596cf82a145093bb893420d79ea93be5ebfc725b From 2e1419815ef0d6c6ebfe708667a90517b51d57ce Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 1 Sep 2025 16:37:44 +0100 Subject: [PATCH 02/20] Implementing features for a subset of ude petab_sciml test cases. 
Excludes: - frozen nn layers - nns in observable formulae --- python/sdist/amici/jax/model.py | 11 + python/sdist/amici/jax/nn.py | 6 +- python/sdist/amici/jax/ode_export.py | 2 - python/sdist/amici/jax/petab.py | 222 ++++++++++++++++-- python/sdist/amici/petab/parameter_mapping.py | 24 +- python/sdist/amici/petab/petab_import.py | 16 +- tests/sciml/test_sciml.py | 36 ++- 7 files changed, 267 insertions(+), 50 deletions(-) diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 917836b6d9..d920bf5ad0 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -545,6 +545,8 @@ def simulate_condition( x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override: jt.Float[jt.Array, "*nx"] = jnp.array([]), + init_override_mask: jt.Bool[jt.Array, "*nx"] = jnp.array([]), ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]), ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]: @@ -588,6 +590,10 @@ def simulate_condition( mask for re-initialization. If `True`, the corresponding state variable is re-initialized. :param x_reinit: re-initialized state vector. If not provided, the state vector is not re-initialized. + :param init_override: + override model input e.g. with neural net outputs. If not provided, the inputs are not overridden. + :param init_override_mask: + mask for input override. If `True`, the corresponding input is replaced with value from init_override. :param ts_mask: mask to remove (padded) time points. If `True`, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2. 
@@ -602,6 +608,11 @@ def simulate_condition( if x_preeq.shape[0]: x = x_preeq + elif init_override.shape[0]: + x_def = self._x0(t0, p) + x = jnp.squeeze( + jnp.where(init_override_mask, init_override, x_def) + ) else: x = self._x0(t0, p) diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py index e26d58693e..d409063c20 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/jax/nn.py @@ -120,7 +120,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F "bias": "use_bias", }, "LayerNorm": { - "elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias) + "elementwise_affine": "use_bias", # Deprecation warning - replace LayerNorm(elementwise_affine) with LayerNorm(use_bias) "normalized_shape": "shape", }, } @@ -190,7 +190,9 @@ def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F82 args = ", ".join([f"{arg}" for arg in node.args]) kwargs = [ - f"{k}={item}" for k, item in node.kwargs.items() if k not in ("inplace",) + f"{k}={item}" + for k, item in node.kwargs.items() + if k not in ("inplace",) ] if layer_type.startswith(("Dropout",)): kwargs += ["key=key"] diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index c57492e390..53b8479155 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -283,7 +283,6 @@ def _generate_jax_code(self) -> None: tpl_data, ) - def _generate_nn_code(self) -> None: for net_name, net in self.hybridisation.items(): generate_equinox( @@ -303,7 +302,6 @@ def _implicit_roots(self) -> list[sp.Expr]: roots.append(root) return roots - def set_paths(self, output_dir: str | Path | None = None) -> None: """ Set output paths for the model and create if necessary diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 29f1e65b66..86bed469c6 100644 --- a/python/sdist/amici/jax/petab.py +++ 
b/python/sdist/amici/jax/petab.py @@ -20,6 +20,7 @@ import pandas as pd import petab.v1 as petab import h5py +import re from amici import _module_from_path from amici.petab.parameter_mapping import ( @@ -95,6 +96,7 @@ class JAXProblem(eqx.Module): model: JAXModel simulation_conditions: tuple[tuple[str, ...], ...] _parameter_mappings: dict[str, ParameterMappingForCondition] + _hybridization_df: pd.DataFrame _ts_dyn: np.ndarray _ts_posteq: np.ndarray _my: np.ndarray @@ -122,6 +124,7 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): scs = petab_problem.get_simulation_conditions_from_measurement_df() self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) self._petab_problem = petab_problem + self._hybridization_df = self._get_hybridization_df() self.parameters, self.model = self._get_nominal_parameter_values(model) self._parameter_mappings = self._get_parameter_mappings(scs) ( @@ -524,28 +527,51 @@ def _get_nominal_parameter_values( for net_id, nn in model.nns.items() } # load nn parameters from file - par_arrays = { - array_id: h5py.File(file_spec["location"], "r") - for array_id, file_spec in self._petab_problem.extensions_config[ - "array_files" - ].items() - # TODO: FIXME (https://github.com/sebapersson/petab_sciml_testsuite/issues/1) - } + # ?? dict similar to the one above ?? {nn_id: {layer_id: {weight, bias}}} ?? 
+ par_arrays = dict( + [ + ( + file_spec.split("_")[0], + h5py.File(file_spec, "r")["parameters"][ + file_spec.split("_")[0] + ], + ) + for file_spec in self._petab_problem.extensions_config[ + "sciml" + ]["array_files"] + if "parameters" in h5py.File(file_spec, "r").keys() + ] + ) + + nn_input_arrays = dict( + [ + (file_spec.split("_")[0], h5py.File(file_spec, "r")["inputs"]) + for file_spec in self._petab_problem.extensions_config[ + "sciml" + ]["array_files"] + if "inputs" in h5py.File(file_spec, "r").keys() + ] + ) # extract nominal values from petab problem for pname, row in self._petab_problem.parameter_df.iterrows(): - if (net := pname.split(".")[0]) in model.nns: + if (net := pname.split("_")[0]) in model.nns: to_set = [] nn = model_pars[net] - try: - value = float(row[petab.NOMINAL_VALUE]) - except ValueError: - value = par_arrays[row[petab.NOMINAL_VALUE]] + scalar = True + + # ?? When value is NaN do we want to go to par_arrays ?? + if np.isnan(row[petab.NOMINAL_VALUE]): + value = par_arrays[net] scalar = False + else: + value = float(row[petab.NOMINAL_VALUE]) + if len(pname.split(".")) > 1: layer_name = pname.split(".")[1] layer = nn[layer_name] if len(pname.split(".")) > 2: + # ?? Recursion needed ?? 
attribute_name = pname.split(".")[2] to_set.append((layer_name, attribute_name)) else: @@ -567,11 +593,11 @@ def _get_nominal_parameter_values( for layer, attribute in to_set: if scalar: nn[layer][attribute] = value * jnp.ones_like( - model.nns[net].layers[layer][attribute] + getattr(model.nns[net].layers[layer], attribute) ) else: nn[layer][attribute] = jnp.array( - value[layer][attribute] + value[layer][attribute][:] ) # set values in model @@ -588,6 +614,38 @@ def _get_nominal_parameter_values( model, model_pars[net_id][layer_id][attribute], ) + + # set inputs in the model if provided + if len(nn_input_arrays) > 0: + for net_id in model_pars: + for input in model.nns[net_id].inputs: + input_array = dict( + [ + ( + input, + dict( + [ + ( + k, + jnp.array( + nn_input_arrays[net_id][input][ + k + ][:], + dtype=jnp.float64, + ), + ) # ?? hardcoded dtype not ideal ?? could infer from env somehow ?? + for k in nn_input_arrays[net_id][ + input + ].keys() + ] + ), + ) + ] + ) + model = eqx.tree_at( + lambda model: model.nns[net_id].inputs, model, input_array + ) + return jnp.array( [ petab.scale( @@ -629,6 +687,17 @@ def _get_inputs(self): ].values.reshape(shape) return inputs + def _get_hybridization_df(self): + if "sciml" in self._petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t", index_col=0) + for hf in self._petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + return hybridization_df + @property def parameter_ids(self) -> list[str]: """ @@ -691,7 +760,7 @@ def _unscale( [jax_unscale(pval, scale) for pval, scale in zip(p, scales)] ) - def _eval_nn(self, output_par: str): + def _eval_nn(self, output_par: str, condition_id: str): net_id = self._petab_problem.mapping_df.loc[ output_par, petab.MODEL_ENTITY_ID ].split(".")[0] @@ -709,15 +778,70 @@ def _eval_nn(self, output_par: str): .to_dict() ) + condition_parameter_map = ( + dict( + [ + ( + petab_id, + 
self._petab_problem.parameter_df.loc[ + self._petab_problem.condition_df.loc[ + condition_id, petab_id + ], + petab.NOMINAL_VALUE, + ], + ) + if "input" + in self._petab_problem.condition_df.loc[ + condition_id, petab_id + ] + else ( + petab_id, + np.float64( + self._petab_problem.condition_df.loc[ + condition_id, petab_id + ] + ), + ) + for petab_id in [ + s for s in model_id_map.values() if "input" in s + ] + ] + ) + if not self._petab_problem.condition_df.empty + else {} + ) + + hybridization_parameter_map = dict( + [ + (petab_id, self._hybridization_df.loc[petab_id, "targetValue"]) + for petab_id in model_id_map.values() + if petab_id in set(self._hybridization_df.index) + ] + ) + + # ?? conditional nightmare ?? refactor ?? net_input = jnp.array( [ - jax.lax.stop_gradient(self._inputs[net_id][model_id]) - if model_id in self._inputs[net_id] + jax.lax.stop_gradient(self.model.nns[net_id][model_id]) + if model_id in self.model.nns[net_id].inputs else self.get_petab_parameter_by_id(petab_id) if petab_id in self.parameter_ids else self._petab_problem.parameter_df.loc[ petab_id, petab.NOMINAL_VALUE ] + if petab_id in set(self._petab_problem.parameter_df.index) + else self.model.nns[net_id].inputs[petab_id][condition_id] + if isinstance(self.model.nns[net_id].inputs, dict) + and condition_id in self.model.nns[net_id].inputs[petab_id] + else self.model.nns[net_id].inputs[petab_id][ + "0" + ] # ?? "0" always the key if inputs for all conditions ?? 
+ if petab_id in self.model.nns[net_id].inputs + else self._petab_problem.parameter_df.loc[ + hybridization_parameter_map[petab_id], petab.NOMINAL_VALUE + ] + if self._petab_problem.condition_df.empty + else condition_parameter_map[petab_id] for model_id, petab_id in model_id_map.items() if model_id.split(".")[1].startswith("input") ] @@ -728,10 +852,20 @@ def _map_model_parameter_value( self, mapping: ParameterMappingForCondition, pname: str, + condition_id: str, ) -> jt.Float[jt.Scalar, ""] | float: # noqa: F722 - if pname in self.nn_output_ids: - return self._eval_nn(pname) pval = mapping.map_sim_var[pname] + if pval in self.nn_output_ids: + nn_output = self._eval_nn(pval, condition_id) + if nn_output.size > 1: + # ?? can this approach work for single dimension return ?? maybe remove the squeeze from _eval_nn ?? + entityId = self._petab_problem.mapping_df.loc[ + pval, petab.MODEL_ENTITY_ID + ] + ind = int(re.search(r"\[\d+\]\[(\d+)\]", entityId).group(1)) + return nn_output[ind] + else: + return nn_output if isinstance(pval, Number): return pval return self.get_petab_parameter_by_id(pval) @@ -751,7 +885,9 @@ def load_model_parameters( p = jnp.array( [ - self._map_model_parameter_value(mapping, pname) + self._map_model_parameter_value( + mapping, pname, simulation_condition + ) for pname in self.model.parameter_ids ] ) @@ -980,6 +1116,8 @@ def _prepare_conditions( in_axes={ "max_steps": None, "self": None, + "init_override": None, # ?? performance hit ?? flip arrays to avoid ?? + "init_override_mask": None, }, # only list arguments here where eqx.is_array(0) is not the right thing ) def run_simulation( @@ -994,6 +1132,8 @@ def run_simulation( nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + init_override: jt.Float[jt.Array, "np"], # ?? what do these annotations mean ?? 
+ init_override_mask: jt.Bool[jt.Array, "np"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1059,6 +1199,8 @@ def run_simulation( x_preeq=x_preeq, mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, + init_override=init_override, + init_override_mask=init_override_mask, ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), solver=solver, controller=controller, @@ -1118,6 +1260,44 @@ def run_simulations( self._np_indices, ) ) + # state values override? + init_override_mask = jnp.stack( + [ + jnp.array( + [ + True + if p + in set(self._parameter_mappings[sc].map_sim_var.keys()) + else False + for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) + if not init_override_mask.any(): + init_override_mask = jnp.array([]) + init_override = jnp.array([]) + else: + init_override = jnp.stack( + [ + jnp.array( + [ + self._eval_nn( + self._parameter_mappings[sc].map_sim_var[p], sc + ) + if p + in set( + self._parameter_mappings[sc].map_sim_var.keys() + ) + else 1.0 # ?? dummy value - shouldn't matter ?? 
+ for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) + return self.run_simulation( p_array, self._ts_dyn, @@ -1129,6 +1309,8 @@ def run_simulations( np_array, mask_reinit_array, x_reinit_array, + init_override, + init_override_mask, solver, controller, root_finder, diff --git a/python/sdist/amici/petab/parameter_mapping.py b/python/sdist/amici/petab/parameter_mapping.py index beac7321bf..7264ea3a14 100644 --- a/python/sdist/amici/petab/parameter_mapping.py +++ b/python/sdist/amici/petab/parameter_mapping.py @@ -384,7 +384,7 @@ def create_parameter_mapping( measurement_df=petab_problem.measurement_df, parameter_df=petab_problem.parameter_df, observable_df=petab_problem.observable_df, - mapping_df=petab_problem.mapping_df, + # mapping_df=petab_problem.mapping_df, model=petab_problem.model, simulation_conditions=simulation_conditions, fill_fixed_parameters=fill_fixed_parameters, @@ -394,6 +394,8 @@ def create_parameter_mapping( ) ) + # ?? put mappings in later ?? after mapping for condition ?? will there be a performance regression as a result ?? + parameter_mapping = ParameterMapping() for (_, condition), prelim_mapping_for_condition in zip( simulation_conditions.iterrows(), prelim_parameter_mapping, strict=True @@ -585,6 +587,26 @@ def create_parameter_mapping_for_condition( ) logger.debug(f"Merged: {condition_map_sim_var}") + # ?? right place for static hybridization here ?? 
+ + if "sciml" in petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t") + for hf in petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + for net_id, config in petab_problem.extensions_config["sciml"][ + "neural_nets" + ].items(): + if config["static"]: + for _, row in hybridization_df.iterrows(): + if row["targetValue"].startswith(net_id): + condition_map_sim_var[row["targetId"]] = row[ + "targetValue" + ] + parameter_mapping_for_condition = ParameterMappingForCondition( map_preeq_fix=condition_map_preeq_fix, map_sim_fix=condition_map_sim_fix, diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 8119d970ac..8f996cc60b 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -129,14 +129,17 @@ def import_petab_problem( logger.info(f"Compiling model {model_name} to {model_output_dir}.") - if "neural_nets" in petab_problem.extensions_config: # TODO: fixme + if "sciml" in petab_problem.extensions_config: from petab_sciml.standard import NNModelStandard - config = petab_problem.extensions_config + config = petab_problem.extensions_config["sciml"] # TODO: only accept YAML format for now - hybridization_table = pd.read_csv( - config["hybridization_file"], sep="\t" - ) + hybridizations = [ + pd.read_csv(hf, sep="\t") + for hf in config["hybridization_files"] + ] + hybridization_table = pd.concat(hybridizations) + input_mapping = dict( zip( hybridization_table["targetId"], @@ -166,6 +169,7 @@ def import_petab_problem( .to_dict() .items() if model_id.split(".")[1].startswith("input") + and petab_id in input_mapping.keys() ], "output_vars": [ output_mapping[petab_id] @@ -179,7 +183,9 @@ def import_petab_problem( .to_dict() .items() if model_id.split(".")[1].startswith("output") + and petab_id in output_mapping.keys() ], + # ?? static included here ?? and handled later ?? 
**net_config, } for net_id, net_config in config["neural_nets"].items() diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 72629d023d..a39c5b8c9f 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -50,9 +50,7 @@ def change_directory(destination): def _reshape_flat_array(array_flat): array_flat["ix"] = array_flat["ix"].astype(str) - ix_cols = [ - f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";"))) - ] + ix_cols = [f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";")))] if len(ix_cols) == 1: array_flat[ix_cols[0]] = array_flat["ix"].apply(int) else: @@ -66,9 +64,7 @@ def _reshape_flat_array(array_flat): return array -@pytest.mark.parametrize( - "test", sorted([d.stem for d in net_cases_dir.glob("[0-9]*")]) -) +@pytest.mark.parametrize("test", sorted([d.stem for d in net_cases_dir.glob("[0-9]*")])) def test_net(test): test_dir = net_cases_dir / test with open(test_dir / "solutions.yaml") as f: @@ -128,9 +124,7 @@ def test_net(test): w = par["parameters"][ml_model.nn_model_id][layer]["weight"][:] if isinstance(net.layers[layer], eqx.nn.ConvTranspose): # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose - w = np.flip( - w, axis=tuple(range(2, w.ndim)) - ).swapaxes(0, 1) + w = np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes(0, 1) assert w.shape == net.layers[layer].weight.shape net = eqx.tree_at( lambda x: x.layers[layer].weight, @@ -176,9 +170,7 @@ def test_net(test): ) -@pytest.mark.parametrize( - "test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")]) -) +@pytest.mark.parametrize("test", sorted([d.stem for d in ude_cases_dir.glob("[0-9]*")])) def test_ude(test): test_dir = ude_cases_dir / test with open(test_dir / "petab" / "problem.yaml") as f: @@ -202,8 +194,16 @@ def test_ude(test): # llh if test in ( + # ?? cases where nn part of observable formula ?? 
"004", - "016", + "009", + "012", + "013", + "018", + "020", + "022", + "025", + "028", ): with pytest.raises(NotImplementedError): run_simulations(jax_problem) @@ -260,13 +260,9 @@ def test_ude(test): ) else: expected = h5py.File(test_dir / file, "r") - for layer_name, layer in jax_problem.model.nns[ - component - ].layers.items(): + for layer_name, layer in jax_problem.model.nns[component].layers.items(): for attribute in dir(layer): - if not isinstance( - getattr(layer, attribute), jax.numpy.ndarray - ): + if not isinstance(getattr(layer, attribute), jax.numpy.ndarray): continue actual = getattr( sllh.model.nns[component].layers[layer_name], attribute @@ -281,7 +277,7 @@ def test_ude(test): ) np.testing.assert_allclose( actual, - expected[layer_name][attribute][:], + expected["parameters"][component][layer_name][attribute][:], atol=solutions["tol_grad_llh"], rtol=solutions["tol_grad_llh"], ) From d2137c3ae2424bea43c74224af946d336d07caaf Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 1 Sep 2025 16:29:16 +0100 Subject: [PATCH 03/20] update petab_sciml workflow - on branches and sciml install branch --- .github/workflows/test_petab_sciml.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml index 98cc735e7a..25839e6b46 100644 --- a/.github/workflows/test_petab_sciml.yml +++ b/.github/workflows/test_petab_sciml.yml @@ -3,11 +3,12 @@ on: push: branches: - develop - - master + - main pull_request: branches: - - master + - main - develop + - jax_sciml merge_group: workflow_dispatch: @@ -59,7 +60,7 @@ jobs: - name: Download and install PEtab SciML run: | source ./venv/bin/activate \ - && python -m pip install git+https://github.com/sebapersson/petab_sciml.git@unify_data#subdirectory=src/python \ + && python -m pip install git+https://github.com/sebapersson/petab_sciml.git@main#subdirectory=src/python \ - name: Install petab From 
75a36308cd4cca10cac9fd953ddc1b03abc84f2f Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 2 Sep 2025 11:37:32 +0100 Subject: [PATCH 04/20] updates to petab sciml workflow --- .github/workflows/test_petab_sciml.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml index 25839e6b46..20ac50a235 100644 --- a/.github/workflows/test_petab_sciml.yml +++ b/.github/workflows/test_petab_sciml.yml @@ -34,6 +34,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 20 + submodules: recursive - name: Install apt dependencies uses: ./.github/actions/install-apt-dependencies @@ -67,7 +68,7 @@ jobs: run: | source ./venv/bin/activate \ && python3 -m pip uninstall -y petab \ - && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@develop \ + && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@sciml \ - name: Run PEtab SciML testsuite run: | From 8a76e4420717e2b7aaa6aa2edb1ee3ad3d04c949 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 2 Sep 2025 15:53:08 +0100 Subject: [PATCH 05/20] fix undef local var in jax tests --- python/sdist/amici/jax/petab.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 86bed469c6..0105d330ba 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -541,7 +541,7 @@ def _get_nominal_parameter_values( ]["array_files"] if "parameters" in h5py.File(file_spec, "r").keys() ] - ) + ) if self._petab_problem.extensions_config else {} nn_input_arrays = dict( [ @@ -551,7 +551,7 @@ def _get_nominal_parameter_values( ]["array_files"] if "inputs" in h5py.File(file_spec, "r").keys() ] - ) + ) if self._petab_problem.extensions_config else {} # extract nominal values from petab problem for pname, row in self._petab_problem.parameter_df.iterrows(): @@ -696,7 +696,9 @@ def 
_get_hybridization_df(self): ] ] hybridization_df = pd.concat(hybridizations) - return hybridization_df + return hybridization_df + else: + return None @property def parameter_ids(self) -> list[str]: From 374922c0233567407a4abf4bcc51a39a0e364711 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 8 Sep 2025 16:19:41 +0100 Subject: [PATCH 06/20] frozen layers for RHS networks generalise frozen layers to networks across system Use stop_grad instead --- python/sdist/amici/jax/petab.py | 96 +++++++++++++++++++++++++++++---- tests/sciml/test_sciml.py | 19 ++++--- 2 files changed, 97 insertions(+), 18 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 0105d330ba..36ae8d0c12 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -16,6 +16,7 @@ import jaxtyping as jt import jax.lax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np import pandas as pd import petab.v1 as petab @@ -1414,6 +1415,7 @@ def run_simulations( ..., diffrax._custom_types.BoolScalarLike ] = diffrax.steady_state_event(), max_steps: int = 2**10, + is_grad_mode: bool = False, ret: ReturnValue | str = ReturnValue.llh, ): """ @@ -1481,16 +1483,39 @@ def run_simulations( for sc in simulation_conditions ] ) - output, results = problem.run_simulations( - dynamic_conditions, - preeq_array, - solver, - controller, - root_finder, - steady_state_event, - max_steps, - ret, - ) + if is_grad_mode: + output, _ = eqx.filter_grad( + grad_filter_run_simulations, has_aux=True + )( + problem, + dynamic_conditions, + preeq_array, + solver, + controller, + root_finder, + steady_state_event, + max_steps, + ret, + ) + results = { + "llh": jnp.array([]), + "stats_dyn": None, + "stats_posteq": None, + "ts": jnp.array([]), + "x": jnp.array([]), + } + + else: + output, results = problem.run_simulations( + dynamic_conditions, + preeq_array, + solver, + controller, + root_finder, + steady_state_event, + max_steps, + ret, + ) 
else: output = jnp.array(0.0) results = { @@ -1501,7 +1526,7 @@ def run_simulations( "x": jnp.array([]), } - if ret in (ReturnValue.llh, ReturnValue.chi2): + if ret in (ReturnValue.llh, ReturnValue.chi2) and not is_grad_mode: output = jnp.sum(output) return output, results | preresults | conditions @@ -1590,3 +1615,52 @@ def petab_simulate( ) dfs.append(df_sc) return pd.concat(dfs).sort_index() + +def apply_grad_filter(problem: JAXProblem,): + for entity in problem._petab_problem.mapping_df[petab.MODEL_ENTITY_ID]: + if "layer" in entity: + net_id = entity.split(".")[0] + layer_id = re.findall(r"\[(.*?)\]", entity)[0] + array_attr = entity.split(".")[-1] + if array_attr in ("weight", "bias"): + problem = eqx.tree_at( + lambda problem: getattr(problem.model.nns[net_id].layers[layer_id], array_attr), + problem, + replace_fn=lambda array_attr: jax.lax.stop_gradient(array_attr) + ) + else: + problem = eqx.tree_at( + lambda problem: problem.model.nns[net_id].layers[layer_id], + problem, + replace_fn=lambda layer: jax.lax.stop_gradient(layer) + ) + + return problem + +def grad_filter_run_simulations( + problem, + simulation_conditions: list[str], + preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + root_finder: AbstractRootFinder, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], + max_steps: jnp.int_, + ret: ReturnValue = ReturnValue.llh, + ): + problem_grad_filtered = apply_grad_filter(problem) + output, stats = problem_grad_filtered.run_simulations( + simulation_conditions, + preeq_array, + solver, + controller, + root_finder, + steady_state_event, + max_steps, + ret, + ) + output = jnp.sum(output) + + return output, stats diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index a39c5b8c9f..91e44977fb 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -234,11 +234,12 @@ def test_ude(test): ) # gradient 
- sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)( + sllh, _ = run_simulations( jax_problem, solver=diffrax.Kvaerno5(), controller=diffrax.PIDController(atol=1e-14, rtol=1e-14), max_steps=2**16, + is_grad_mode=True, ) for component, file in solutions["grad_files"].items(): actual_dict = {} @@ -246,6 +247,7 @@ def test_ude(test): expected = pd.read_csv(test_dir / file, sep="\t").set_index( petab.PARAMETER_ID ) + for ip in expected.index: if ip in jax_problem.parameter_ids: actual_dict[ip] = sllh.parameters[ @@ -275,9 +277,12 @@ def test_ude(test): actual.swapaxes(0, 1), axis=tuple(range(2, actual.ndim)), ) - np.testing.assert_allclose( - actual, - expected["parameters"][component][layer_name][attribute][:], - atol=solutions["tol_grad_llh"], - rtol=solutions["tol_grad_llh"], - ) + if np.squeeze(expected["parameters"][component][layer_name][attribute][:]).size == 0: + assert np.all(actual == 0.0) + else: + np.testing.assert_allclose( + np.squeeze(actual), + np.squeeze(expected["parameters"][component][layer_name][attribute][:]), + atol=solutions["tol_grad_llh"], + rtol=solutions["tol_grad_llh"], + ) From c2a386bcd662b18f5b98f4103f18059b7109d550 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Wed, 17 Sep 2025 14:56:07 +0100 Subject: [PATCH 07/20] implement nns in the observable formula --- python/sdist/amici/de_model.py | 41 +++++++++++++++++++++++- python/sdist/amici/petab/petab_import.py | 21 +++++++++++- tests/sciml/test_sciml.py | 15 --------- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 1625fed73d..019518cb0f 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2564,6 +2564,7 @@ def _process_hybridization(self, hybridization: dict) -> None: hybridization information """ added_expressions = False + orig_obs = tuple([s.get_id() for s in self._observables]) for net_id, net in hybridization.items(): if net["static"]: continue # do not 
integrate into ODEs, handle in amici.jax.petab @@ -2606,6 +2607,7 @@ def _process_hybridization(self, hybridization: dict) -> None: raise ValueError( f"Could not find all output variables for neural network {net_id}" ) + for iout, (out_var, comp) in enumerate(outputs.items()): # remove output from model components if isinstance(comp, Parameter): @@ -2620,8 +2622,10 @@ def _process_hybridization(self, hybridization: dict) -> None: ) # generate dummy Function + # FIXME: not robust to an observable output and a regular output being in the other order + ind = iout + len(net["observable_vars"]) out_val = sp.Function(net_id)( - *[input.get_id() for input in inputs], iout + *[input.get_id() for input in inputs], ind ) # add to the model @@ -2641,6 +2645,41 @@ def _process_hybridization(self, hybridization: dict) -> None: ) ) added_expressions = True + + observables = { + ob_var: comp + for comp in self._components + if (ob_var := str(comp.get_id())) in net["observable_vars"] + # TODO: SYNTAX NEEDS to CHANGE + or (ob_var := str(comp.get_id()) + "_dot") + in net["observable_vars"] + } + if len(observables.keys()) != len(net["observable_vars"]): + raise ValueError( + f"Could not find all observable variables for neural network {net_id}" + ) + + for iout, (ob_var, comp) in enumerate(observables.items()): + if isinstance(comp, Observable): + self._observables.remove(comp) + else: + raise ValueError( + f"{comp.get_name()} ({type(comp)}) is not an observable." 
+ ) + out_val = sp.Function(net_id)( + *[input.get_id() for input in inputs], iout + ) + # add to the model + self.add_component( + Observable( + identifier=comp.get_id(), + name=net_id, + value=out_val, + ) + ) + + new_order = [orig_obs.index(s.get_id()) for s in self._observables] + self._observables = [self._observables[i] for i in new_order] if added_expressions: # toposort expressions diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 8f996cc60b..940d794f71 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -152,6 +152,12 @@ def import_petab_problem( hybridization_table["targetId"], ) ) + observable_mapping = dict( + zip( + petab_problem.observable_df["observableFormula"], + petab_problem.observable_df.index, + ) + ) hybridization = { net_id: { "model": NNModelStandard.load_data( @@ -185,7 +191,20 @@ def import_petab_problem( if model_id.split(".")[1].startswith("output") and petab_id in output_mapping.keys() ], - # ?? static included here ?? and handled later ?? + "observable_vars": [ + observable_mapping[petab_id] + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("output") + and petab_id in observable_mapping.keys() + ], **net_config, } for net_id, net_config in config["neural_nets"].items() diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 91e44977fb..3faf06b3f4 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -193,21 +193,6 @@ def test_ude(test): jax_problem = JAXProblem(jax_model, petab_problem) # llh - if test in ( - # ?? cases where nn part of observable formula ?? 
- "004", - "009", - "012", - "013", - "018", - "020", - "022", - "025", - "028", - ): - with pytest.raises(NotImplementedError): - run_simulations(jax_problem) - return llh, r = run_simulations(jax_problem) np.testing.assert_allclose( llh, From 8bce41615e1dedb08367b588ed0e122f40b54f8d Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 19 Sep 2025 16:13:49 +0100 Subject: [PATCH 08/20] tidy, refactor, generalise sciml test case implementations --- python/sdist/amici/de_model.py | 28 +-- python/sdist/amici/jax/petab.py | 179 ++++++++++-------- python/sdist/amici/petab/parameter_mapping.py | 13 +- python/sdist/amici/petab/petab_import.py | 78 +++++--- tests/sciml/test_sciml.py | 17 +- 5 files changed, 190 insertions(+), 125 deletions(-) diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 019518cb0f..db343019a4 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2596,7 +2596,7 @@ def _process_hybridization(self, hybridization: dict) -> None: ) outputs = { - out_var: comp + out_var: {"comp": comp, "ind": net["output_vars"][out_var]} for comp in self._components if (out_var := str(comp.get_id())) in net["output_vars"] # TODO: SYNTAX NEEDS to CHANGE @@ -2607,8 +2607,9 @@ def _process_hybridization(self, hybridization: dict) -> None: raise ValueError( f"Could not find all output variables for neural network {net_id}" ) - - for iout, (out_var, comp) in enumerate(outputs.items()): + + for out_var, parts in outputs.items(): + comp = parts["comp"] # remove output from model components if isinstance(comp, Parameter): self._parameters.remove(comp) @@ -2622,10 +2623,8 @@ def _process_hybridization(self, hybridization: dict) -> None: ) # generate dummy Function - # FIXME: not robust to an observable output and a regular output being in the other order - ind = iout + len(net["observable_vars"]) out_val = sp.Function(net_id)( - *[input.get_id() for input in inputs], ind + *[input.get_id() for input in inputs], 
parts["ind"] ) # add to the model @@ -2645,21 +2644,22 @@ def _process_hybridization(self, hybridization: dict) -> None: ) ) added_expressions = True - + observables = { - ob_var: comp + ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]} for comp in self._components if (ob_var := str(comp.get_id())) in net["observable_vars"] - # TODO: SYNTAX NEEDS to CHANGE - or (ob_var := str(comp.get_id()) + "_dot") - in net["observable_vars"] + # # TODO: SYNTAX NEEDS to CHANGE + # or (ob_var := str(comp.get_id()) + "_dot") + # in net["observable_vars"] } if len(observables.keys()) != len(net["observable_vars"]): raise ValueError( f"Could not find all observable variables for neural network {net_id}" ) - - for iout, (ob_var, comp) in enumerate(observables.items()): + + for ob_var, parts in observables.items(): + comp = parts["comp"] if isinstance(comp, Observable): self._observables.remove(comp) else: @@ -2667,7 +2667,7 @@ def _process_hybridization(self, hybridization: dict) -> None: f"{comp.get_name()} ({type(comp)}) is not an observable." ) out_val = sp.Function(net_id)( - *[input.get_id() for input in inputs], iout + *[input.get_id() for input in inputs], parts["ind"] ) # add to the model self.add_component( diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 36ae8d0c12..79650364fb 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -528,31 +528,41 @@ def _get_nominal_parameter_values( for net_id, nn in model.nns.items() } # load nn parameters from file - # ?? dict similar to the one above ?? {nn_id: {layer_id: {weight, bias}}} ?? 
- par_arrays = dict( - [ - ( - file_spec.split("_")[0], - h5py.File(file_spec, "r")["parameters"][ - file_spec.split("_")[0] - ], - ) - for file_spec in self._petab_problem.extensions_config[ - "sciml" - ]["array_files"] - if "parameters" in h5py.File(file_spec, "r").keys() - ] - ) if self._petab_problem.extensions_config else {} + par_arrays = ( + dict( + [ + ( + file_spec.split("_")[0], + h5py.File(file_spec, "r")["parameters"][ + file_spec.split("_")[0] + ], + ) + for file_spec in self._petab_problem.extensions_config[ + "sciml" + ]["array_files"] + if "parameters" in h5py.File(file_spec, "r").keys() + ] + ) + if self._petab_problem.extensions_config + else {} + ) - nn_input_arrays = dict( - [ - (file_spec.split("_")[0], h5py.File(file_spec, "r")["inputs"]) - for file_spec in self._petab_problem.extensions_config[ - "sciml" - ]["array_files"] - if "inputs" in h5py.File(file_spec, "r").keys() - ] - ) if self._petab_problem.extensions_config else {} + nn_input_arrays = ( + dict( + [ + ( + file_spec.split("_")[0], + h5py.File(file_spec, "r")["inputs"], + ) + for file_spec in self._petab_problem.extensions_config[ + "sciml" + ]["array_files"] + if "inputs" in h5py.File(file_spec, "r").keys() + ] + ) + if self._petab_problem.extensions_config + else {} + ) # extract nominal values from petab problem for pname, row in self._petab_problem.parameter_df.iterrows(): @@ -561,7 +571,6 @@ def _get_nominal_parameter_values( nn = model_pars[net] scalar = True - # ?? When value is NaN do we want to go to par_arrays ?? if np.isnan(row[petab.NOMINAL_VALUE]): value = par_arrays[net] scalar = False @@ -572,7 +581,6 @@ def _get_nominal_parameter_values( layer_name = pname.split(".")[1] layer = nn[layer_name] if len(pname.split(".")) > 2: - # ?? Recursion needed ?? 
attribute_name = pname.split(".")[2] to_set.append((layer_name, attribute_name)) else: @@ -632,9 +640,11 @@ def _get_nominal_parameter_values( nn_input_arrays[net_id][input][ k ][:], - dtype=jnp.float64, + dtype=jnp.float64 + if jax.config.jax_enable_x64 + else jnp.float32, ), - ) # ?? hardcoded dtype not ideal ?? could infer from env somehow ?? + ) for k in nn_input_arrays[net_id][ input ].keys() @@ -769,19 +779,22 @@ def _eval_nn(self, output_par: str, condition_id: str): ].split(".")[0] nn = self.model.nns[net_id] + def _is_net_input(model_id): + comps = model_id.split(".") + return comps[0] == net_id and comps[1].startswith("inputs") + model_id_map = ( self._petab_problem.mapping_df[ - self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] - .str.split(".") - .str[0] - == net_id + self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID].apply( + _is_net_input + ) ] .reset_index() .set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID] .to_dict() ) - condition_parameter_map = ( + condition_input_map = ( dict( [ ( @@ -793,10 +806,10 @@ def _eval_nn(self, output_par: str, condition_id: str): petab.NOMINAL_VALUE, ], ) - if "input" - in self._petab_problem.condition_df.loc[ + if self._petab_problem.condition_df.loc[ condition_id, petab_id ] + in self._petab_problem.parameter_df.index else ( petab_id, np.float64( @@ -805,9 +818,7 @@ def _eval_nn(self, output_par: str, condition_id: str): ] ), ) - for petab_id in [ - s for s in model_id_map.values() if "input" in s - ] + for petab_id in model_id_map.values() ] ) if not self._petab_problem.condition_df.empty @@ -822,7 +833,28 @@ def _eval_nn(self, output_par: str, condition_id: str): ] ) - # ?? conditional nightmare ?? refactor ?? 
+ # handle conditions + if len(condition_input_map) > 0: + net_input = jnp.array( + [ + condition_input_map[petab_id] + for _, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + + # handle array inputs + if isinstance(self.model.nns[net_id].inputs, dict): + net_input = jnp.array( + [ + self.model.nns[net_id].inputs[petab_id][condition_id] + if condition_id in self.model.nns[net_id].inputs[petab_id] + else self.model.nns[net_id].inputs[petab_id]["0"] + for _, petab_id in model_id_map.items() + ] + ) + return nn.forward(net_input).squeeze() + net_input = jnp.array( [ jax.lax.stop_gradient(self.model.nns[net_id][model_id]) @@ -833,20 +865,10 @@ def _eval_nn(self, output_par: str, condition_id: str): petab_id, petab.NOMINAL_VALUE ] if petab_id in set(self._petab_problem.parameter_df.index) - else self.model.nns[net_id].inputs[petab_id][condition_id] - if isinstance(self.model.nns[net_id].inputs, dict) - and condition_id in self.model.nns[net_id].inputs[petab_id] - else self.model.nns[net_id].inputs[petab_id][ - "0" - ] # ?? "0" always the key if inputs for all conditions ?? - if petab_id in self.model.nns[net_id].inputs else self._petab_problem.parameter_df.loc[ hybridization_parameter_map[petab_id], petab.NOMINAL_VALUE ] - if self._petab_problem.condition_df.empty - else condition_parameter_map[petab_id] for model_id, petab_id in model_id_map.items() - if model_id.split(".")[1].startswith("input") ] ) return nn.forward(net_input).squeeze() @@ -861,7 +883,6 @@ def _map_model_parameter_value( if pval in self.nn_output_ids: nn_output = self._eval_nn(pval, condition_id) if nn_output.size > 1: - # ?? can this approach work for single dimension return ?? maybe remove the squeeze from _eval_nn ?? entityId = self._petab_problem.mapping_df.loc[ pval, petab.MODEL_ENTITY_ID ] @@ -1119,8 +1140,6 @@ def _prepare_conditions( in_axes={ "max_steps": None, "self": None, - "init_override": None, # ?? performance hit ?? flip arrays to avoid ?? 
- "init_override_mask": None, }, # only list arguments here where eqx.is_array(0) is not the right thing ) def run_simulation( @@ -1135,8 +1154,8 @@ def run_simulation( nps: jt.Float[jt.Array, "nt *nnp"], # noqa: F821, F722 mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 - init_override: jt.Float[jt.Array, "np"], # ?? what do these annotations mean ?? - init_override_mask: jt.Bool[jt.Array, "np"], + init_override: jt.Float[jt.Array, "nx"], # noqa: F821, F722 + init_override_mask: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, root_finder: AbstractRootFinder, @@ -1203,7 +1222,7 @@ def run_simulation( mask_reinit=jax.lax.stop_gradient(mask_reinit), x_reinit=x_reinit, init_override=init_override, - init_override_mask=init_override_mask, + init_override_mask=jax.lax.stop_gradient(init_override_mask), ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)), solver=solver, controller=controller, @@ -1263,7 +1282,7 @@ def run_simulations( self._np_indices, ) ) - # state values override? + init_override_mask = jnp.stack( [ jnp.array( @@ -1279,8 +1298,12 @@ def run_simulations( ] ) if not init_override_mask.any(): - init_override_mask = jnp.array([]) - init_override = jnp.array([]) + init_override_mask = jnp.stack( + [jnp.array([]) for _ in simulation_conditions] + ) + init_override = jnp.stack( + [jnp.array([]) for _ in simulation_conditions] + ) else: init_override = jnp.stack( [ @@ -1293,7 +1316,7 @@ def run_simulations( in set( self._parameter_mappings[sc].map_sim_var.keys() ) - else 1.0 # ?? dummy value - shouldn't matter ?? 
+ else 1.0 for p in self.model.state_ids ] ) @@ -1616,7 +1639,10 @@ def petab_simulate( dfs.append(df_sc) return pd.concat(dfs).sort_index() -def apply_grad_filter(problem: JAXProblem,): + +def apply_grad_filter( + problem: JAXProblem, +): for entity in problem._petab_problem.mapping_df[petab.MODEL_ENTITY_ID]: if "layer" in entity: net_id = entity.split(".")[0] @@ -1624,32 +1650,35 @@ def apply_grad_filter(problem: JAXProblem,): array_attr = entity.split(".")[-1] if array_attr in ("weight", "bias"): problem = eqx.tree_at( - lambda problem: getattr(problem.model.nns[net_id].layers[layer_id], array_attr), + lambda problem: getattr( + problem.model.nns[net_id].layers[layer_id], array_attr + ), problem, - replace_fn=lambda array_attr: jax.lax.stop_gradient(array_attr) + replace_fn=lambda array_attr: jax.lax.stop_gradient( + array_attr + ), ) else: problem = eqx.tree_at( lambda problem: problem.model.nns[net_id].layers[layer_id], problem, - replace_fn=lambda layer: jax.lax.stop_gradient(layer) + replace_fn=lambda layer: jax.lax.stop_gradient(layer), ) return problem + def grad_filter_run_simulations( - problem, - simulation_conditions: list[str], - preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 - solver: diffrax.AbstractSolver, - controller: diffrax.AbstractStepSizeController, - root_finder: AbstractRootFinder, - steady_state_event: Callable[ - ..., diffrax._custom_types.BoolScalarLike - ], - max_steps: jnp.int_, - ret: ReturnValue = ReturnValue.llh, - ): + problem, + simulation_conditions: list[str], + preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + root_finder: AbstractRootFinder, + steady_state_event: Callable[..., diffrax._custom_types.BoolScalarLike], + max_steps: jnp.int_, + ret: ReturnValue = ReturnValue.llh, +): problem_grad_filtered = apply_grad_filter(problem) output, stats = problem_grad_filtered.run_simulations( simulation_conditions, diff 
--git a/python/sdist/amici/petab/parameter_mapping.py b/python/sdist/amici/petab/parameter_mapping.py index 7264ea3a14..0f70ce0d71 100644 --- a/python/sdist/amici/petab/parameter_mapping.py +++ b/python/sdist/amici/petab/parameter_mapping.py @@ -378,13 +378,20 @@ def create_parameter_mapping( if parameter_mapping_kwargs is None: parameter_mapping_kwargs = {} + # TODO: Add support for conditions with sciml mappings in petab library + mapping = ( + None + if "sciml" in petab_problem.extensions_config + else petab_problem.mapping_df + ) + prelim_parameter_mapping = ( petab.get_optimization_to_simulation_parameter_mapping( condition_df=petab_problem.condition_df, measurement_df=petab_problem.measurement_df, parameter_df=petab_problem.parameter_df, observable_df=petab_problem.observable_df, - # mapping_df=petab_problem.mapping_df, + mapping_df=mapping, model=petab_problem.model, simulation_conditions=simulation_conditions, fill_fixed_parameters=fill_fixed_parameters, @@ -394,8 +401,6 @@ def create_parameter_mapping( ) ) - # ?? put mappings in later ?? after mapping for condition ?? will there be a performance regression as a result ?? - parameter_mapping = ParameterMapping() for (_, condition), prelim_mapping_for_condition in zip( simulation_conditions.iterrows(), prelim_parameter_mapping, strict=True @@ -587,8 +592,6 @@ def create_parameter_mapping_for_condition( ) logger.debug(f"Merged: {condition_map_sim_var}") - # ?? right place for static hybridization here ?? 
- if "sciml" in petab_problem.extensions_config: hybridizations = [ pd.read_csv(hf, sep="\t") diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 940d794f71..5ae20cd4c3 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -9,6 +9,8 @@ import os import shutil from pathlib import Path +from warnings import warn +import re import amici import pandas as pd @@ -177,34 +179,44 @@ def import_petab_problem( if model_id.split(".")[1].startswith("input") and petab_id in input_mapping.keys() ], - "output_vars": [ - output_mapping[petab_id] - for petab_id, model_id in petab_problem.mapping_df.loc[ - petab_problem.mapping_df[petab.MODEL_ENTITY_ID] - .str.split(".") - .str[0] - == net_id, - petab.MODEL_ENTITY_ID, + "output_vars": dict( + [ + ( + output_mapping[petab_id], + _get_net_index(model_id), + ) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("output") + and petab_id in output_mapping.keys() ] - .to_dict() - .items() - if model_id.split(".")[1].startswith("output") - and petab_id in output_mapping.keys() - ], - "observable_vars": [ - observable_mapping[petab_id] - for petab_id, model_id in petab_problem.mapping_df.loc[ - petab_problem.mapping_df[petab.MODEL_ENTITY_ID] - .str.split(".") - .str[0] - == net_id, - petab.MODEL_ENTITY_ID, + ), + "observable_vars": dict( + [ + ( + observable_mapping[petab_id], + _get_net_index(model_id), + ) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if model_id.split(".")[1].startswith("output")
+ and petab_id in observable_mapping.keys() ] - .to_dict() - .items() - if model_id.split(".")[1].startswith("output") - and petab_id in observable_mapping.keys() - ], + ), **net_config, } for net_id, net_config in config["neural_nets"].items() @@ -258,3 +273,14 @@ def import_petab_problem( ) return model + + +def _get_net_index(model_id: str): + matches = re.findall(r"\[(\d+)\]", model_id) + if matches: + return int(matches[-1]) + return None + + +# for backwards compatibility +import_model = import_model_sbml diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 3faf06b3f4..80f0881e55 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -120,7 +120,6 @@ def test_net(test): and hasattr(net.layers[layer], "weight") and net.layers[layer].weight is not None ): - # ?? grabbing weights from the parameters file ?? need to check if they are present in above if condition ?? w = par["parameters"][ml_model.nn_model_id][layer]["weight"][:] if isinstance(net.layers[layer], eqx.nn.ConvTranspose): # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose @@ -136,7 +135,6 @@ def test_net(test): and hasattr(net.layers[layer], "bias") and net.layers[layer].bias is not None ): - # ?? grabbing biases from the parameters file ?? need to check if they are present in above if condition ?? 
b = par["parameters"][ml_model.nn_model_id][layer]["bias"][:] if isinstance( net.layers[layer], @@ -232,7 +230,7 @@ def test_ude(test): expected = pd.read_csv(test_dir / file, sep="\t").set_index( petab.PARAMETER_ID ) - + for ip in expected.index: if ip in jax_problem.parameter_ids: actual_dict[ip] = sllh.parameters[ @@ -262,12 +260,21 @@ def test_ude(test): actual.swapaxes(0, 1), axis=tuple(range(2, actual.ndim)), ) - if np.squeeze(expected["parameters"][component][layer_name][attribute][:]).size == 0: + if ( + np.squeeze( + expected["parameters"][component][layer_name][attribute][:] + ).size + == 0 + ): assert np.all(actual == 0.0) else: np.testing.assert_allclose( np.squeeze(actual), - np.squeeze(expected["parameters"][component][layer_name][attribute][:]), + np.squeeze( + expected["parameters"][component][layer_name][ + attribute + ][:] + ), atol=solutions["tol_grad_llh"], rtol=solutions["tol_grad_llh"], ) From b3225552b2dd88bcb3d16eb7424146767d3df0df Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 30 Sep 2025 12:03:35 +0100 Subject: [PATCH 09/20] hybridization df in _petab_problem - makes JAXProblem jit-able --- python/sdist/amici/jax/petab.py | 100 ++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 79650364fb..e160b568d7 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -7,6 +7,7 @@ from pathlib import Path from collections.abc import Callable import logging +from typing import Union import diffrax @@ -76,6 +77,30 @@ def jax_unscale( return jnp.power(10, parameter) raise ValueError(f"Invalid parameter scaling: {scale_str}") +# IDEA: Implement hybridization_df in petab.v2.Problem instead? 
Then class here could be removed +class HybridProblem(petab.Problem): + hybridization_df: pd.DataFrame + + def __init__(self, petab_problem: petab.Problem): + self.__dict__.update(petab_problem.__dict__) + self.hybridization_df = _get_hybridization_df(petab_problem) + + +def _get_hybridization_df(petab_problem): + if "sciml" in petab_problem.extensions_config: + hybridizations = [ + pd.read_csv(hf, sep="\t", index_col=0) + for hf in petab_problem.extensions_config["sciml"][ + "hybridization_files" + ] + ] + hybridization_df = pd.concat(hybridizations) + return hybridization_df + else: + return None + +def _get_hybrid_petab_problem(petab_problem: petab.Problem): + return HybridProblem(petab_problem) class JAXProblem(eqx.Module): """ @@ -97,7 +122,6 @@ class JAXProblem(eqx.Module): model: JAXModel simulation_conditions: tuple[tuple[str, ...], ...] _parameter_mappings: dict[str, ParameterMappingForCondition] - _hybridization_df: pd.DataFrame _ts_dyn: np.ndarray _ts_posteq: np.ndarray _my: np.ndarray @@ -111,7 +135,7 @@ class JAXProblem(eqx.Module): _np_mask: np.ndarray _np_indices: np.ndarray _petab_measurement_indices: np.ndarray - _petab_problem: petab.Problem + _petab_problem: petab.Problem | HybridProblem def __init__(self, model: JAXModel, petab_problem: petab.Problem): """ @@ -124,8 +148,7 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): """ scs = petab_problem.get_simulation_conditions_from_measurement_df() self.simulation_conditions = tuple(tuple(sc) for sc in scs.values) - self._petab_problem = petab_problem - self._hybridization_df = self._get_hybridization_df() + self._petab_problem = _get_hybrid_petab_problem(petab_problem) self.parameters, self.model = self._get_nominal_parameter_values(model) self._parameter_mappings = self._get_parameter_mappings(scs) ( @@ -698,19 +721,6 @@ def _get_inputs(self): ].values.reshape(shape) return inputs - def _get_hybridization_df(self): - if "sciml" in self._petab_problem.extensions_config: - 
hybridizations = [ - pd.read_csv(hf, sep="\t", index_col=0) - for hf in self._petab_problem.extensions_config["sciml"][ - "hybridization_files" - ] - ] - hybridization_df = pd.concat(hybridizations) - return hybridization_df - else: - return None - @property def parameter_ids(self) -> list[str]: """ @@ -827,9 +837,9 @@ def _is_net_input(model_id): hybridization_parameter_map = dict( [ - (petab_id, self._hybridization_df.loc[petab_id, "targetValue"]) + (petab_id, self._petab_problem.hybridization_df.loc[petab_id, "targetValue"]) for petab_id in model_id_map.values() - if petab_id in set(self._hybridization_df.index) + if petab_id in set(self._petab_problem.hybridization_df.index) ] ) @@ -1297,32 +1307,32 @@ def run_simulations( for sc in simulation_conditions ] ) - if not init_override_mask.any(): - init_override_mask = jnp.stack( - [jnp.array([]) for _ in simulation_conditions] - ) - init_override = jnp.stack( - [jnp.array([]) for _ in simulation_conditions] - ) - else: - init_override = jnp.stack( - [ - jnp.array( - [ - self._eval_nn( - self._parameter_mappings[sc].map_sim_var[p], sc - ) - if p - in set( - self._parameter_mappings[sc].map_sim_var.keys() - ) - else 1.0 - for p in self.model.state_ids - ] - ) - for sc in simulation_conditions - ] - ) + # if init_override_mask.sum() == 0.0: + # init_override_mask = jnp.stack( + # [jnp.array([]) for _ in simulation_conditions] + # ) + # init_override = jnp.stack( + # [jnp.array([]) for _ in simulation_conditions] + # ) + # else: + init_override = jnp.stack( + [ + jnp.array( + [ + self._eval_nn( + self._parameter_mappings[sc].map_sim_var[p], sc + ) + if p + in set( + self._parameter_mappings[sc].map_sim_var.keys() + ) + else 1.0 + for p in self.model.state_ids + ] + ) + for sc in simulation_conditions + ] + ) return self.run_simulation( p_array, From dddc4c2285454d0c858af50baf800e51daa5647e Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 3 Oct 2025 16:57:47 +0100 Subject: [PATCH 10/20] update frozen 
layer/arrays implementation --- python/sdist/amici/jax/nn.py | 22 +++- python/sdist/amici/jax/ode_export.py | 3 +- python/sdist/amici/jax/petab.py | 127 ++++------------------- python/sdist/amici/petab/petab_import.py | 27 +++++ tests/sciml/test_sciml.py | 3 +- 5 files changed, 72 insertions(+), 110 deletions(-) diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py index d409063c20..f94a0dcca4 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/jax/nn.py @@ -30,7 +30,7 @@ def tanhshrink(x: jnp.ndarray) -> jnp.ndarray: return x - jnp.tanh(x) -def generate_equinox(nn_model: "NNModel", filename: Path | str): # noqa: F821 +def generate_equinox(nn_model: "NNModel", filename: Path | str, frozen_layers: dict = {}): # noqa: F821 # TODO: move to top level import and replace forward type definitions from petab_sciml import Layer @@ -53,6 +53,7 @@ def generate_equinox(nn_model: "NNModel", filename: Path | str): # noqa: F821 _generate_forward( node, node_indent, + frozen_layers, layers.get( node.target, Layer(layer_id="dummy", layer_type="Linear"), @@ -149,13 +150,25 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F return f"{' ' * indent}'{layer.layer_id}': {layer_str}" -def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F821 +def _generate_forward(node: "Node", indent, frozen_layers: dict = {}, layer_type=str) -> str: # noqa: F821 if node.op == "placeholder": # TODO: inconsistent target vs name return f"{' ' * indent}{node.name} = input" if node.op == "call_module": fun_str = f"self.layers['{node.target}']" + if node.name in frozen_layers: + if frozen_layers[node.name]: + arr_attr = frozen_layers[node.name] + get_lambda = f"lambda layer: getattr(layer, '{arr_attr}')" + replacer = ( + "replace_fn = lambda arr: jax.lax.stop_gradient(arr)" + ) + tree_string = f"tree_{node.name} = eqx.tree_at({get_lambda}, {fun_str}, {replacer})" + fun_str = f"tree_{node.name}" + else: + fun_str = 
f"jax.lax.stop_gradient({fun_str})" tree_string = "" if layer_type.startswith(("Conv", "Linear", "LayerNorm")): if layer_type in ("LayerNorm",): dims = f"len({fun_str}.shape)+1" @@ -198,6 +211,9 @@ def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F82 kwargs += ["key=key"] kwargs_str = ", ".join(kwargs) if node.op in ("call_module", "call_function", "call_method"): - return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})" + if node.name in frozen_layers: + return f"{' ' * indent}{tree_string}\n{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})" + else: + return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})" if node.op == "output": return f"{' ' * indent}{node.target} = {args}" diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 53b8479155..69cc413c73 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -6,7 +6,7 @@ The user generally won't have to directly call any function from this module as this will be done by :py:func:`amici.pysb_import.pysb2jax`, -:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and +:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and :py:func:`amici.petab_import.import_model`. """ @@ -288,6 +288,7 @@ def _generate_nn_code(self) -> None: generate_equinox( net["model"], self.model_path / f"{net_name}.py", + net["frozen_layers"], ) def _implicit_roots(self) -> list[sp.Expr]: diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index e160b568d7..b649071f80 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -77,7 +77,8 @@ def jax_unscale( return jnp.power(10, parameter) raise ValueError(f"Invalid parameter scaling: {scale_str}") -# IDEA: Implement hybridization_df in petab.v2.Problem instead? Then class here could be removed + +# IDEA: Implement this class in petab-sciml instead? 
class HybridProblem(petab.Problem): hybridization_df: pd.DataFrame @@ -98,10 +99,12 @@ def _get_hybridization_df(petab_problem): return hybridization_df else: return None - + + def _get_hybrid_petab_problem(petab_problem: petab.Problem): return HybridProblem(petab_problem) + class JAXProblem(eqx.Module): """ PEtab problem wrapper for JAX models. @@ -837,7 +840,12 @@ def _is_net_input(model_id): hybridization_parameter_map = dict( [ - (petab_id, self._petab_problem.hybridization_df.loc[petab_id, "targetValue"]) + ( + petab_id, + self._petab_problem.hybridization_df.loc[ + petab_id, "targetValue" + ], + ) for petab_id in model_id_map.values() if petab_id in set(self._petab_problem.hybridization_df.index) ] @@ -1307,14 +1315,6 @@ def run_simulations( for sc in simulation_conditions ] ) - # if init_override_mask.sum() == 0.0: - # init_override_mask = jnp.stack( - # [jnp.array([]) for _ in simulation_conditions] - # ) - # init_override = jnp.stack( - # [jnp.array([]) for _ in simulation_conditions] - # ) - # else: init_override = jnp.stack( [ jnp.array( @@ -1323,9 +1323,7 @@ def run_simulations( self._parameter_mappings[sc].map_sim_var[p], sc ) if p - in set( - self._parameter_mappings[sc].map_sim_var.keys() - ) + in set(self._parameter_mappings[sc].map_sim_var.keys()) else 1.0 for p in self.model.state_ids ] @@ -1448,7 +1446,6 @@ def run_simulations( ..., diffrax._custom_types.BoolScalarLike ] = diffrax.steady_state_event(), max_steps: int = 2**10, - is_grad_mode: bool = False, ret: ReturnValue | str = ReturnValue.llh, ): """ @@ -1516,39 +1513,16 @@ def run_simulations( for sc in simulation_conditions ] ) - if is_grad_mode: - output, _ = eqx.filter_grad( - grad_filter_run_simulations, has_aux=True - )( - problem, - dynamic_conditions, - preeq_array, - solver, - controller, - root_finder, - steady_state_event, - max_steps, - ret, - ) - results = { - "llh": jnp.array([]), - "stats_dyn": None, - "stats_posteq": None, - "ts": jnp.array([]), - "x": jnp.array([]), - } - - 
else: - output, results = problem.run_simulations( - dynamic_conditions, - preeq_array, - solver, - controller, - root_finder, - steady_state_event, - max_steps, - ret, - ) + output, results = problem.run_simulations( + dynamic_conditions, + preeq_array, + solver, + controller, + root_finder, + steady_state_event, + max_steps, + ret, + ) else: output = jnp.array(0.0) results = { @@ -1559,7 +1533,7 @@ def run_simulations( "x": jnp.array([]), } - if ret in (ReturnValue.llh, ReturnValue.chi2) and not is_grad_mode: + if ret in (ReturnValue.llh, ReturnValue.chi2): output = jnp.sum(output) return output, results | preresults | conditions @@ -1648,58 +1622,3 @@ def petab_simulate( ) dfs.append(df_sc) return pd.concat(dfs).sort_index() - - -def apply_grad_filter( - problem: JAXProblem, -): - for entity in problem._petab_problem.mapping_df[petab.MODEL_ENTITY_ID]: - if "layer" in entity: - net_id = entity.split(".")[0] - layer_id = re.findall(r"\[(.*?)\]", entity)[0] - array_attr = entity.split(".")[-1] - if array_attr in ("weight", "bias"): - problem = eqx.tree_at( - lambda problem: getattr( - problem.model.nns[net_id].layers[layer_id], array_attr - ), - problem, - replace_fn=lambda array_attr: jax.lax.stop_gradient( - array_attr - ), - ) - else: - problem = eqx.tree_at( - lambda problem: problem.model.nns[net_id].layers[layer_id], - problem, - replace_fn=lambda layer: jax.lax.stop_gradient(layer), - ) - - return problem - - -def grad_filter_run_simulations( - problem, - simulation_conditions: list[str], - preeq_array: jt.Float[jt.Array, "ncond *nx"], # noqa: F821, F722 - solver: diffrax.AbstractSolver, - controller: diffrax.AbstractStepSizeController, - root_finder: AbstractRootFinder, - steady_state_event: Callable[..., diffrax._custom_types.BoolScalarLike], - max_steps: jnp.int_, - ret: ReturnValue = ReturnValue.llh, -): - problem_grad_filtered = apply_grad_filter(problem) - output, stats = problem_grad_filtered.run_simulations( - simulation_conditions, - preeq_array, - 
solver, - controller, - root_finder, - steady_state_event, - max_steps, - ret, - ) - output = jnp.sum(output) - - return output, stats diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 5ae20cd4c3..438256fee2 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -220,6 +220,25 @@ def import_petab_problem( and petab_id in observable_mapping.keys() ] ), + "frozen_layers": dict( + [ + _get_frozen_layers(model_id) + for petab_id, model_id in petab_problem.mapping_df.loc[ + petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id, + petab.MODEL_ENTITY_ID, + ] + .to_dict() + .items() + if petab_id in petab_problem.parameter_df.index + and petab_problem.parameter_df.loc[ + petab_id, petab.ESTIMATE + ] + == 0 + ] + ), **net_config, } for net_id, net_config in config["neural_nets"].items() @@ -282,5 +301,13 @@ def _get_net_index(model_id: str): return None +def _get_frozen_layers(model_id): + layers = re.findall(r"\[(.*?)\]", model_id) + array_attr = model_id.split(".")[-1] + layer_id = layers[0] if len(layers) else None + array_attr = array_attr if array_attr in ("weight", "bias") else None + return layer_id, array_attr + + # for backwards compatibility import_model = import_model_sbml diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 80f0881e55..3f00a5747d 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -217,12 +217,11 @@ def test_ude(test): ) # gradient - sllh, _ = run_simulations( + sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)( jax_problem, solver=diffrax.Kvaerno5(), controller=diffrax.PIDController(atol=1e-14, rtol=1e-14), max_steps=2**16, - is_grad_mode=True, ) for component, file in solutions["grad_files"].items(): actual_dict = {} From 5f5fb8c7de497b946c0623129d7e319fd9e1ed35 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 6 Oct 2025 11:18:37 +0100 Subject: [PATCH 
11/20] update jax petab notebook --- .../example_jax_petab/ExampleJaxPEtab.ipynb | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 983e139237..2e16bf7d50 100644 --- a/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -82,7 +82,9 @@ "cell_type": "markdown", "id": "415962751301c64a", "metadata": {}, - "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." + "source": [ + "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." + ] }, { "cell_type": "code", @@ -94,7 +96,9 @@ "simulation_condition = (\"model1_data1\",)\n", "\n", "# Access the results for the specified condition\n", - "results[simulation_condition]" + "ic = results[\"simulation_conditions\"].index(simulation_condition)\n", + "print(\"llh: \", results[\"llh\"][ic])\n", + "print(\"state variables: \", results[\"x\"][ic, :])" ] }, { @@ -129,7 +133,9 @@ "cell_type": "markdown", "id": "fe4d3b40ee3efdf2", "metadata": {}, - "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." + "source": [ + "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." + ] }, { "cell_type": "code", @@ -180,7 +186,9 @@ "cell_type": "markdown", "id": "4fa97c33719c2277", "metadata": {}, - "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." 
+ "source": [ + "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." + ] }, { "cell_type": "code", @@ -283,7 +291,9 @@ "cell_type": "markdown", "id": "dc9bc07cde00a926", "metadata": {}, - "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." + "source": [ + "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." + ] }, { "cell_type": "code", @@ -302,7 +312,9 @@ "cell_type": "markdown", "id": "851c3ec94cb5d086", "metadata": {}, - "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`." + "source": [ + "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`." + ] }, { "cell_type": "code", @@ -318,7 +330,9 @@ "cell_type": "markdown", "id": "375b835fecc5a022", "metadata": {}, - "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." 
+ "source": [ + "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." + ] }, { "cell_type": "code", @@ -334,7 +348,9 @@ "cell_type": "markdown", "id": "8eb7cc3db510c826", "metadata": {}, - "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." + "source": [ + "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." + ] }, { "cell_type": "code", @@ -342,14 +358,16 @@ "metadata": {}, "outputs": [], "source": [ - "grad._measurements[simulation_condition]" + "grad._my[ic, :]" ] }, { "cell_type": "markdown", "id": "58eb04393a1463d", "metadata": {}, - "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation." + "source": [ + "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. 
While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation." + ] }, { "cell_type": "code", @@ -671,8 +689,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.0" + "pygments_lexer": "ipython3" } }, "nbformat": 4, From ddc68fa785e6dc5f5af532f1e594202ecd53767e Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Mon, 6 Oct 2025 11:27:49 +0100 Subject: [PATCH 12/20] add h5py to docs deps --- doc/rtd_requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index cbb21058c2..235b11d00e 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -16,6 +16,7 @@ sphinx_rtd_theme>=1.2.0 petab[vis]>=0.2.0 sphinx-autodoc-typehints ipython>=8.13.2 +h5py>=3.14.0 breathe>=4.35.0 exhale>=0.3.7 -e git+https://github.com/mithro/sphinx-contrib-mithro#egg=sphinx-contrib-exhale-multiproject&subdirectory=sphinx-contrib-exhale-multiproject From 9e744c7afa30f1032d204685a2cdb01416bbdd67 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 7 Oct 2025 10:30:42 +0100 Subject: [PATCH 13/20] fix sbml jax tests --- python/sdist/amici/jax/jaxcodeprinter.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/jaxcodeprinter.py b/python/sdist/amici/jax/jaxcodeprinter.py index f96186ac3b..ed022e9fd6 100644 --- a/python/sdist/amici/jax/jaxcodeprinter.py +++ b/python/sdist/amici/jax/jaxcodeprinter.py @@ -6,6 +6,7 @@ import sympy as sp from sympy.printing.numpy import NumPyPrinter +from sympy.core.function import UndefinedFunction def _jnp_array_str(array) -> str: @@ -42,8 +43,11 @@ def _print_Mul(self, expr: sp.Expr) -> str: return super()._print_Mul(expr) return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})" - def _print_Function(self, expr): - return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', 
'.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]" + def _print_Function(self, expr: sp.Expr) -> str: + if isinstance(expr.func, UndefinedFunction): + return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]" + else: + return super()._print_Function(expr) def _print_Max(self, expr: sp.Expr) -> str: """ From 021ea450899b7846752afea99a0eb100908ed3bc Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 7 Oct 2025 14:58:55 +0100 Subject: [PATCH 14/20] missed rebased imports --- python/sdist/amici/petab/petab_import.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 438256fee2..486c34c432 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -9,11 +9,7 @@ import os import shutil from pathlib import Path -<<<<<<< HEAD -======= -from warnings import warn import re ->>>>>>> 05f22cc9 (tidy, refactor, generalise sciml test case implementations) import amici import pandas as pd From e630100e7f7a38e33ded5e2f3405576a5f960a15 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 7 Oct 2025 16:39:16 +0100 Subject: [PATCH 15/20] codecov maybe --- python/sdist/amici/jax/petab.py | 2 -- python/sdist/amici/petab/petab_import.py | 1 - 2 files changed, 3 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index b649071f80..dde7ea1c8c 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -97,8 +97,6 @@ def _get_hybridization_df(petab_problem): ] hybridization_df = pd.concat(hybridizations) return hybridization_df - else: - return None def _get_hybrid_petab_problem(petab_problem: petab.Problem): diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 486c34c432..0317c14256 100644 --- 
a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -294,7 +294,6 @@ def _get_net_index(model_id: str): matches = re.findall(r"\[(\d+)\]", model_id) if matches: return int(matches[-1]) - return None def _get_frozen_layers(model_id): From 08e58a03222e0c2de9b4291b56cca57bb9759dc7 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 10 Oct 2025 08:56:31 +0100 Subject: [PATCH 16/20] codecov - update cov file name --- .github/workflows/test_petab_sciml.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml index 20ac50a235..23f182012a 100644 --- a/.github/workflows/test_petab_sciml.yml +++ b/.github/workflows/test_petab_sciml.yml @@ -73,7 +73,7 @@ jobs: - name: Run PEtab SciML testsuite run: | source ./venv/bin/activate \ - && pytest --cov-report=xml:coverage.xml \ + && pytest --cov-report=xml:coverage_petab_sciml.xml \ --cov=./ tests/sciml/test_sciml.py - name: Codecov @@ -81,6 +81,6 @@ jobs: uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} - file: coverage.xml - flags: petab + file: coverage_petab_sciml.xml + flags: petab_sciml fail_ci_if_error: true From 55363e191a0212772c136eb1e9005e60913b50a8 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 10 Oct 2025 10:38:58 +0100 Subject: [PATCH 17/20] codecov - specify cov path --- .github/workflows/test_petab_sciml.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml index 23f182012a..c01af88739 100644 --- a/.github/workflows/test_petab_sciml.yml +++ b/.github/workflows/test_petab_sciml.yml @@ -74,7 +74,7 @@ jobs: run: | source ./venv/bin/activate \ && pytest --cov-report=xml:coverage_petab_sciml.xml \ - --cov=./ tests/sciml/test_sciml.py + --cov=amici tests/sciml/test_sciml.py - name: Codecov if: github.event_name == 'pull_request' || 
github.repository_owner == 'AMICI-dev' From 6782793fa250e3fa8a486503dc9d3695c74e87c6 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 14 Oct 2025 11:11:05 +0100 Subject: [PATCH 18/20] enable zero params case --- python/sdist/amici/jax/petab.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index dde7ea1c8c..481643a95f 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1104,17 +1104,21 @@ def _prepare_conditions( p_array = jnp.stack( [self.load_model_parameters(sc) for sc in conditions] ) - unscaled_parameters = jnp.stack( - [ - jax_unscale( - self.parameters[ip], - self._petab_problem.parameter_df.loc[ - p_id, petab.PARAMETER_SCALE - ], - ) - for ip, p_id in enumerate(self.parameter_ids) - ] - ) + + if self.parameters.size: + unscaled_parameters = jnp.stack( + [ + jax_unscale( + self.parameters[ip], + self._petab_problem.parameter_df.loc[ + p_id, petab.PARAMETER_SCALE + ], + ) + for ip, p_id in enumerate(self.parameter_ids) + ] + ) + else: + unscaled_parameters = jnp.zeros((*self._ts_masks.shape[:2], 0)) if op_numeric is not None and op_numeric.size: op_array = jnp.where( From 3ad23096edb39aeb7b4c34f261c3dd166cff2946 Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 14 Oct 2025 12:21:31 +0100 Subject: [PATCH 19/20] doc build forward type definition workaround --- doc/conf.py | 5 +++++ doc/rtd_requirements.txt | 2 ++ 2 files changed, 7 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index 78e8534768..99f1396c0f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -33,6 +33,7 @@ import amici import pandas as pd # noqa: F401 import sympy as sp # noqa: F401 +import warnings def install_doxygen(): @@ -364,6 +365,10 @@ def install_doxygen(): "ExpDataPtrVector": ":class:`amici.amici.ExpData`", } +# TODO: alias for forward type definition, remove after release of petab_sciml +autodoc_type_aliases = { + 
"NNModel": "petab_sciml.NNModel", +} def process_docstring(app, what, name, obj, options, lines): # only apply in the amici.amici module diff --git a/doc/rtd_requirements.txt b/doc/rtd_requirements.txt index 235b11d00e..da9ccf87ba 100644 --- a/doc/rtd_requirements.txt +++ b/doc/rtd_requirements.txt @@ -7,6 +7,8 @@ setuptools>=67.7.2 # https://github.com/pysb/pysb/pull/599 # for building the documentation, we don't care whether this fully works git+https://github.com/pysb/pysb@0afeaab385e9a1d813ecf6fdaf0153f4b91358af +# For forward type definition in generate_equinox +git+https://github.com/PEtab-dev/petab_sciml.git@727d177fd3f85509d0bdcc278b672e9eeafd2384#subdirectory=src/python matplotlib>=3.7.1 optax nbsphinx From 23227aa2eca96b8d615314f2d3aae77652f354ac Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Tue, 14 Oct 2025 16:17:42 +0100 Subject: [PATCH 20/20] simplify array input processing --- python/sdist/amici/jax/petab.py | 38 +++++++++++---------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 481643a95f..520c141d9c 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -651,32 +651,18 @@ def _get_nominal_parameter_values( # set inputs in the model if provided if len(nn_input_arrays) > 0: for net_id in model_pars: - for input in model.nns[net_id].inputs: - input_array = dict( - [ - ( - input, - dict( - [ - ( - k, - jnp.array( - nn_input_arrays[net_id][input][ - k - ][:], - dtype=jnp.float64 - if jax.config.jax_enable_x64 - else jnp.float32, - ), - ) - for k in nn_input_arrays[net_id][ - input - ].keys() - ] - ), - ) - ] - ) + input_array = { + input: { + k: jnp.array( + arr[:], + dtype=jnp.float64 + if jax.config.jax_enable_x64 + else jnp.float32, + ) + for k, arr in nn_input_arrays[net_id][input].items() + } + for input in model.nns[net_id].inputs + } model = eqx.tree_at( lambda model: model.nns[net_id].inputs, model, 
input_array )