diff --git a/python/sdist/amici/gradient_check.py b/python/sdist/amici/gradient_check.py index 116b7c44f5..a9805798e7 100644 --- a/python/sdist/amici/gradient_check.py +++ b/python/sdist/amici/gradient_check.py @@ -66,20 +66,29 @@ def check_finite_difference( finite difference step-size """ + p = copy.deepcopy(x0) + plist = [ip] + + # store original settings and apply new ones og_sensitivity_order = solver.get_sensitivity_order() og_parameters = model.get_parameters() og_plist = model.get_parameter_list() if edata: og_eplist = edata.plist + og_eparameters = edata.parameters - # sensitivity - p = copy.deepcopy(x0) - plist = [ip] + edata.plist = plist + # we always set parameters via the model below + edata.parameters = [] + pscale = ( + edata.pscale if len(edata.pscale) else model.get_parameter_scale() + ) + else: + pscale = model.get_parameter_scale() + model.set_parameter_list(plist) + model.set_parameter_scale(pscale) model.set_parameters(p) - model.set_parameter_list(plist) - if edata: - edata.plist = plist # simulation with gradient if int(og_sensitivity_order) < int(SensitivityOrder.first): @@ -93,8 +102,7 @@ def check_finite_difference( pf = copy.deepcopy(x0) pb = copy.deepcopy(x0) - pscale = model.get_parameter_scale()[ip] - if x0[ip] == 0 or pscale != int(ParameterScaling.none): + if x0[ip] == 0 or pscale[ip] != int(ParameterScaling.none): pf[ip] += epsilon / 2 pb[ip] -= epsilon / 2 else: @@ -142,6 +150,7 @@ def check_finite_difference( model.set_parameter_list(og_plist) if edata: edata.plist = og_eplist + edata.parameters = og_eparameters def check_derivatives( @@ -160,7 +169,8 @@ def check_derivatives( :param model: amici model :param solver: amici solver - :param edata: exp data + :param edata: ExpData instance. If provided, ExpData settings will + override model settings where applicable (`plist`, `parmeters`, ...). :param atol: absolute tolerance for comparison :param rtol: relative tolerance for comparison :param epsilon: finite difference step-size @@ -169,7 +179,10 @@ def check_derivatives( are zero :param skip_fields: list of fields to skip """ - p = np.array(model.get_parameters()) + if edata and edata.parameters: + p = np.array(edata.parameters) + else: + p = np.array(model.get_parameters()) og_sens_order = solver.get_sensitivity_order() diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 0962f57215..bba279249f 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -203,12 +203,6 @@ def check_derivatives( petab_problem=problem, problem_parameters=problem_parameters, ): - # check_derivatives does currently not support parameters in ExpData - # set parameter scales before setting parameter values! - model.set_parameter_scale(edata.pscale) - model.set_parameters(edata.parameters) - edata.parameters = [] - edata.pscale = amici.parameter_scaling_from_int_vector([]) amici_check_derivatives(model, solver, edata)