diff --git a/CHANGELOG.md b/CHANGELOG.md index bf6e93e047..5aef0134dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,7 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni multi-processing contexts. This works only if the amici-generated model package exists in the same file system location and does not change until unpickling. +* `amici.ExpData` is now picklable. ## v0.X Series diff --git a/python/sdist/amici/swig_wrappers.py b/python/sdist/amici/swig_wrappers.py index 4562581e88..53092fb806 100644 --- a/python/sdist/amici/swig_wrappers.py +++ b/python/sdist/amici/swig_wrappers.py @@ -397,3 +397,26 @@ def file_checksum( for chunk in iter(lambda: f.read(chunk_size), b""): h.update(chunk) return h.hexdigest() + + +def restore_edata( + init_args: Sequence, + simulation_parameter_dict: dict[str, Any], +) -> amici_swig.ExpData: + """ + Recreate an ExpData instance. + + For use in ExpData.__reduce__. + """ + edata = amici_swig.ExpData(*init_args) + + edata.pscale = amici.parameter_scaling_from_int_vector( + simulation_parameter_dict.pop("pscale") + ) + for key, value in simulation_parameter_dict.items(): + if key == "timepoints": + # timepoints are set during ExpData construction + continue + assert hasattr(edata, key) + setattr(edata, key, value) + return edata diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index db1d731d8f..783fed7d9d 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -697,3 +697,18 @@ def test_pickle_model(sbml_example_presimulation_module): model.get_steady_state_sensitivity_mode() != model_pickled.get_steady_state_sensitivity_mode() ) + + +def test_pickle_edata(): + ny = 2 + nz = 3 + ne = 4 + nt = 5 + edata = amici.ExpData(ny, nz, ne, range(nt)) + edata.set_observed_data(list(np.arange(ny * nt, dtype=float))) + edata.pscale = amici.parameter_scaling_from_int_vector( + [amici.ParameterScaling.log10] * 5 + ) + + edata_pickled = pickle.loads(pickle.dumps(edata)) + assert edata == edata_pickled diff --git a/swig/amici.i b/swig/amici.i index 7b7929c078..2acf550c97 100644 --- a/swig/amici.i +++ b/swig/amici.i @@ -155,6 +155,23 @@ wrap_unique_ptr(ExpDataPtr, amici::ExpData) %naturalvar amici::SimulationParameters::reinitialization_state_idxs_sim; %naturalvar amici::SimulationParameters::reinitialization_state_idxs_presim; +%extend amici::SimulationParameters { +%pythoncode %{ + def __iter__(self): + for attr_name in dir(self): + if ( + not attr_name.startswith('_') + and attr_name not in ("this", "thisown") + and not callable(attr_val := getattr(self, attr_name)) + ): + if isinstance(attr_val, (DoubleVector, ParameterScalingVector)): + yield attr_name, tuple(attr_val) + else: + yield attr_name, attr_val +%} +} + + // DO NOT IGNORE amici::SimulationParameters, amici::ModelDimensions, amici::CpuTimer %ignore amici::ModelContext; %ignore amici::ContextManager; diff --git a/swig/edata.i b/swig/edata.i index 276d1bc911..257b4cefd4 100644 --- a/swig/edata.i +++ b/swig/edata.i @@ -102,6 +102,27 @@ def __deepcopy__(self, memo): # invoke copy constructor return type(self)(self) +def __reduce__(self): + from amici.swig_wrappers import restore_edata + + return ( + restore_edata, + ( + # ExpData ctor arguments + ( + self.nytrue(), + self.nztrue(), + self.nmaxevent(), + self.get_timepoints(), + self.get_observed_data(), + self.get_observed_data_std_dev(), + self.get_observed_events(), + self.get_observed_events_std_dev(), + ), + dict(self) + ), + {} + ) %} }; %extend std::unique_ptr {