From 4f69ac316ec482efe0b42e7f917dd1a536f9359d Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 21 Jul 2025 13:04:54 +0200 Subject: [PATCH] Event-handling during post-equilibration for FSA * Implement event-handling during post-equilibration for no sensitivities and FSA (#2775) (ASA to be implemented separately) * Additional tests for events * Fix SBML import for parameters that are event targets and have non-time-dependent intial assignments (They previously ended up both in `x` and `w`.) * Drop unused PeriodResult.nroots (meanwhile, this is recomputed from discs) --- include/amici/backwardproblem.h | 4 + include/amici/forwardproblem.h | 4 - include/amici/solver.h | 9 ++ include/amici/sundials_matrix_wrapper.h | 1 - python/sdist/amici/sbml_import.py | 6 +- python/tests/test_events.py | 126 ++++++++++++++++++++---- src/backwardproblem.cpp | 1 + src/forwardproblem.cpp | 25 +++-- src/rdata.cpp | 4 +- src/solver.cpp | 8 ++ 10 files changed, 153 insertions(+), 35 deletions(-) diff --git a/include/amici/backwardproblem.h b/include/amici/backwardproblem.h index e24f4bf9c1..401340669f 100644 --- a/include/amici/backwardproblem.h +++ b/include/amici/backwardproblem.h @@ -212,6 +212,10 @@ class SteadyStateBackwardProblem { /** * @brief Launch backward simulation if Newton solver or linear system solve * fail or are disabled. + * + * This does not perform any event-handling. + * For event-handling, see EventHandlingBwdSimulator. + * * @param solver Solver instance. */ void run_simulation(Solver const& solver); diff --git a/include/amici/forwardproblem.h b/include/amici/forwardproblem.h index dd43ff8d40..9389c68918 100644 --- a/include/amici/forwardproblem.h +++ b/include/amici/forwardproblem.h @@ -179,10 +179,6 @@ struct PeriodResult { /** Discontinuities encountered so far (dimension: dynamic) */ std::vector discs; - /** array of number of found roots for a certain event type - * (dimension: ne) */ - std::vector nroots; - /** simulation states history at timepoints */ std::map timepoint_states_; diff --git a/include/amici/solver.h b/include/amici/solver.h index db00c3d199..6714e61251 100644 --- a/include/amici/solver.h +++ b/include/amici/solver.h @@ -673,6 +673,15 @@ class Solver { */ void writeSolution(SolutionState& sol) const; + /** + * @brief write solution from forward simulation + * @param t Time for which to retrieve the solution + * (interpolated if necessary). Must be greater than or equal to + * the initial timepoint and less than or equal to the current timepoint. + * @param sol solution state + */ + void writeSolution(realtype t, SolutionState& sol) const; + /** * @brief write solution from backward simulation * @param t time diff --git a/include/amici/sundials_matrix_wrapper.h b/include/amici/sundials_matrix_wrapper.h index 26dbb9adee..0686b12eb9 100644 --- a/include/amici/sundials_matrix_wrapper.h +++ b/include/amici/sundials_matrix_wrapper.h @@ -591,7 +591,6 @@ class SUNMatrixWrapper { bool ownmat = true; }; - /** * @brief Output formatter for SUNMatrixWrapper. * @param os output stream diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index eeb8087563..8d4db42e91 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -1728,7 +1728,11 @@ def _convert_event_assignment_parameter_targets_to_species(self): "assignment rules." ) parameter_def = None - for symbol_id in {SymbolId.PARAMETER, SymbolId.FIXED_PARAMETER}: + for symbol_id in { + SymbolId.PARAMETER, + SymbolId.FIXED_PARAMETER, + SymbolId.EXPRESSION, + }: if parameter_target in self.symbols[symbol_id]: # `parameter_target` should only exist in one of the # `symbol_id` dictionaries. diff --git a/python/tests/test_events.py b/python/tests/test_events.py index 87b99d3070..6d1f48c16f 100644 --- a/python/tests/test_events.py +++ b/python/tests/test_events.py @@ -1049,25 +1049,113 @@ def test_event_uses_values_from_trigger_time(tempdir): # generate synthetic measurements edata = amici.ExpData(rdata, 1, 0) - # check forward sensitivities against finite differences - # FIXME: sensitivities w.r.t. the bolus parameter of the first event - # are wrong - model.setParameterList( - [ - ip - for ip, par in enumerate(model.getParameterIds()) - if par not in ["one"] - ] - ) + # check sensitivities against finite differences + + for sens_method in ( + SensitivityMethod.forward, + SensitivityMethod.adjoint, + ): + if sens_method == SensitivityMethod.forward: + # FIXME: forward sensitivities w.r.t. the bolus parameter + # of the first event are wrong + model.setParameterList( + [ + ip + for ip, par in enumerate(model.getParameterIds()) + if par not in ["one"] + ] + ) + elif sens_method == SensitivityMethod.adjoint: + # FIXME: adjoint sensitivities w.r.t. the bolus parameter `three` + # are wrong. + # maybe related to https://github.com/AMICI-dev/AMICI/issues/2805 + model.setParameterList( + [ + ip + for ip, par in enumerate(model.getParameterIds()) + if par not in ["one", "three"] + ] + ) - check_derivatives( - model, - solver, - edata=edata, - atol=1e-6, - rtol=1e-6, - # smaller than the offset from the trigger time - epsilon=1e-8, + solver.setSensitivityMethod(sens_method) + check_derivatives( + model, + solver, + edata=edata, + atol=1e-6, + rtol=1e-6, + # smaller than the offset from the trigger time + epsilon=1e-8, + ) + + +@skip_on_valgrind +def test_posteq_events_are_handled(tempdir): + """Test that events are handled during post-equilibration.""" + from amici.antimony_import import antimony2amici + + model_name = "test_posteq_events_are_handled" + antimony2amici( + r""" + some_time = 0 + some_time' = piecewise(1, (time < 10), 0) + + bolus = 1 + target_initial = 0 + target = target_initial + E1: at time > 1: target = target + bolus + E2: at some_time >= 2: target = target + bolus + """, + observables={ + "obs_target": { + "formula": "target", + } + }, + model_name=model_name, + output_dir=tempdir, + verbose=True, ) - # TODO: test ASA after https://github.com/AMICI-dev/AMICI/pull/1539 + model_module = import_model_module(model_name, tempdir) + model = model_module.get_model() + solver = model.getSolver() + + # test without post-equilibration + model.setTimepoints([10]) + rdata = amici.runAmiciSimulation(model, solver) + assert rdata.status == amici.AMICI_SUCCESS + assert rdata.by_id("target").squeeze() == 2.0 + assert rdata.by_id("obs_target").squeeze() == 2.0 + + # test with post-equilibration + model.setSteadyStateComputationMode( + amici.SteadyStateComputationMode.integrationOnly + ) + model.setSteadyStateSensitivityMode( + amici.SteadyStateSensitivityMode.integrationOnly + ) + model.setTimepoints([np.inf]) + rdata = amici.runAmiciSimulation(model, solver) + assert rdata.status == amici.AMICI_SUCCESS + assert rdata.by_id("target").squeeze() == 2.0 + assert rdata.by_id("obs_target").squeeze() == 2.0 + assert rdata.posteq_t == 10.0 + + # check sensitivities against finite differences + edata = amici.ExpData(rdata, 1, 0, 0) + for sens_method in ( + SensitivityMethod.forward, + # FIXME: sensitivities w.r.t. the bolus parameter are off for ASA (0.0) + # SensitivityMethod.adjoint, + ): + solver.setSensitivityOrder(SensitivityOrder.first) + solver.setSensitivityMethod(sens_method) + check_derivatives( + model, + solver, + edata=edata, + atol=1e-12, + rtol=1e-7, + epsilon=1e-8, + skip_fields=["res"], + ) diff --git a/src/backwardproblem.cpp b/src/backwardproblem.cpp index c1442ae52e..9454ff6d3f 100644 --- a/src/backwardproblem.cpp +++ b/src/backwardproblem.cpp @@ -112,6 +112,7 @@ void BackwardProblem::handlePostequilibration() { } } + // TODO handle any events auto final_state = posteq_problem_->getFinalSimulationState(); posteq_problem_bwd_.emplace(*solver_, *model_, final_state.sol, &ws_); posteq_problem_bwd_->run(model_->t0()); diff --git a/src/forwardproblem.cpp b/src/forwardproblem.cpp index da8639f5d7..b905b3ffa3 100644 --- a/src/forwardproblem.cpp +++ b/src/forwardproblem.cpp @@ -147,7 +147,7 @@ void EventHandlingSimulator::run( fill_events(model_->nMaxEvent(), edata); } - result.nroots = ws_->nroots; + result.final_state_ = {.sol = ws_->sol, .mod = model_->getModelState()}; } void EventHandlingSimulator::run_steady_state( @@ -194,13 +194,20 @@ void EventHandlingSimulator::run_steady_state( // ensure stable computation. // The value is not important for AMICI_ONE_STEP mode, only the // direction w.r.t. current t. - auto status = solver_->step(std::max(ws_->sol.t, 1.0) * 10); - ws_->sol.t = solver_->gett(); + auto tout = std::isfinite(next_t_event) + ? next_t_event + : std::max(ws_->sol.t, 1.0) * 10; + auto status = solver_->step(tout); solver_->writeSolution(ws_->sol); if (status < 0) { throw IntegrationFailure(status, ws_->sol.t); - } else if (status == AMICI_ROOT_RETURN || ws_->sol.t == next_t_event) { + } else if (status == AMICI_ROOT_RETURN || ws_->sol.t >= next_t_event) { + if (ws_->sol.t >= next_t_event) { + // Solver::step will over-step next_t_event + solver_->writeSolution(next_t_event, ws_->sol); + } + // solver-tracked or time-triggered event solver_->getRootInfo(ws_->roots_found.data()); @@ -214,11 +221,16 @@ void EventHandlingSimulator::run_steady_state( ws_->roots_found[ie] = std::copysign(1, -ws_->rootvals[ie]); } ++it_trigger_timepoints; + next_t_event = it_trigger_timepoints != trigger_timepoints.end() + ? *it_trigger_timepoints + : std::numeric_limits::infinity(); } handle_events(false, nullptr); } } + + result.final_state_ = {.sol = ws_->sol, .mod = model_->getModelState()}; } void ForwardProblem::workForwardProblem() { @@ -678,10 +690,7 @@ ForwardProblem::getAdjointUpdates(Model& model, ExpData const& edata) { } SimulationState EventHandlingSimulator::get_simulation_state() { - return { - .sol = ws_->sol, - .mod = model_->getModelState() - }; + return {.sol = ws_->sol, .mod = model_->getModelState()}; } std::vector compute_nroots( diff --git a/src/rdata.cpp b/src/rdata.cpp index 5f85f7c08a..4a7f532307 100644 --- a/src/rdata.cpp +++ b/src/rdata.cpp @@ -259,9 +259,9 @@ void ReturnData::processPostEquilibration( ExpData const* edata ) { for (int it = 0; it < nt; it++) { - auto t = model.getTimepoint(it); + auto const t = model.getTimepoint(it); if (std::isinf(t)) { - auto const simulation_state = posteq.getFinalSimulationState(); + auto const& simulation_state = posteq.getFinalSimulationState(); model.setModelState(simulation_state.mod); getDataOutput(it, model, simulation_state.sol, edata); } diff --git a/src/solver.cpp b/src/solver.cpp index 09854524d5..aebf3a1cba 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -1336,6 +1336,14 @@ void Solver::writeSolution(SolutionState& sol) const { sol.dx.copy(getDerivativeState(sol.t)); } +void Solver::writeSolution(realtype const t, SolutionState& sol) const { + sol.t = t; + if (sens_initialized_) + sol.sx.copy(getStateSensitivity(sol.t)); + sol.x.copy(getState(sol.t)); + sol.dx.copy(getDerivativeState(sol.t)); +} + void Solver::writeSolutionB( realtype& t, AmiVector& xB, AmiVector& dxB, AmiVector& xQB, int const which ) const {