diff --git a/CHANGELOG.md b/CHANGELOG.md index 836df019ea..bf6e93e047 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,10 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni This is a wrapper for both `amici.run_simulation` and `amici.run_simulations`, depending on the type of the `edata` argument. It also supports passing some `Solver` options as keyword arguments. +* `amici.ModelPtr` now supports sufficient pickling for use in + multi-processing contexts. This works only if the amici-generated model + package exists in the same file system location and does not change until + unpickling. ## v0.X Series diff --git a/python/sdist/amici/swig_wrappers.py b/python/sdist/amici/swig_wrappers.py index 9c1df065d5..4562581e88 100644 --- a/python/sdist/amici/swig_wrappers.py +++ b/python/sdist/amici/swig_wrappers.py @@ -6,6 +6,7 @@ import logging import warnings from collections.abc import Sequence +from pathlib import Path from typing import Any import amici @@ -342,3 +343,57 @@ def _Model__simulate( solver=_get_ptr(solver), edata=_get_ptr(edata), ) + + +def restore_model( + module_name: str, module_path: Path, settings: dict, checksum: str = None +) -> amici.Model: + """ + Recreate a model instance with given settings. + + For use in ModelPtr.__reduce__. + + :param module_name: + Name of the model module. + :param module_path: + Path to the model module. + :param settings: + Model settings to be applied. + See `set_model_settings` / `get_model_settings`. + :param checksum: + Checksum of the model extension to verify integrity. + """ + from . import import_model_module + + model_module = import_model_module(module_name, module_path) + model = model_module.get_model() + model.module = model_module._self + set_model_settings(model, settings) + + if checksum is not None and checksum != file_checksum( + model.module.extension_path + ): + raise RuntimeError( + f"Model file checksum does not match the expected checksum " + f"({checksum}). The model file may have been modified " + f"after the model was pickled." + ) + + return model + + +def file_checksum( + path: str | Path, algorithm: str = "sha256", chunk_size: int = 8192 +) -> str: + """ + Compute checksum for `path` using `algorithm` (e.g. 'md5', 'sha1', 'sha256'). + Returns the hexadecimal digest string. + """ + import hashlib + + path = Path(path) + h = hashlib.new(algorithm) + with path.open("rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + h.update(chunk) + return h.hexdigest() diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index a622abb1ef..db1d731d8f 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -5,12 +5,15 @@ import copy import numbers +import pickle from math import nan import amici import numpy as np import pytest import xarray +from amici import SteadyStateSensitivityMode +from amici.testing import skip_on_valgrind def test_version_number(pysb_example_presimulation_module): @@ -664,3 +667,33 @@ def test_reporting_mode_obs_llh(sbml_example_presimulation_module): assert rdata.ssigmay is None assert rdata.sllh.size > 0 assert not np.isnan(rdata.sllh).any() + + +@skip_on_valgrind +def test_pickle_model(sbml_example_presimulation_module): + model_module = sbml_example_presimulation_module + model = model_module.get_model() + + assert ( + model.get_steady_state_sensitivity_mode() + == SteadyStateSensitivityMode.integrationOnly + ) + model.set_steady_state_sensitivity_mode( + SteadyStateSensitivityMode.newtonOnly + ) + + model_pickled = pickle.loads(pickle.dumps(model)) + # ensure it's re-picklable + model_pickled = pickle.loads(pickle.dumps(model_pickled)) + assert ( + model_pickled.get_steady_state_sensitivity_mode() + == SteadyStateSensitivityMode.newtonOnly + ) + + model_pickled.set_steady_state_sensitivity_mode( + SteadyStateSensitivityMode.integrateIfNewtonFails + ) + assert ( + model.get_steady_state_sensitivity_mode() + != model_pickled.get_steady_state_sensitivity_mode() + ) diff --git a/swig/model.i b/swig/model.i index b3699aaa26..01ccab766c 100644 --- a/swig/model.i +++ b/swig/model.i @@ -195,6 +195,20 @@ def simulate( def __deepcopy__(self, memo): return self.clone() +def __reduce__(self): + from amici.swig_wrappers import restore_model, get_model_settings, file_checksum + + return ( + restore_model, + ( + self.get_name(), + Path(self.module.__spec__.origin).parent, + get_model_settings(self), + file_checksum(self.module.extension_path), + ), + {} + ) + @overload def simulate( diff --git a/swig/modelname.template.i b/swig/modelname.template.i index f0a701e9cc..c3d953c7fa 100644 --- a/swig/modelname.template.i +++ b/swig/modelname.template.i @@ -8,10 +8,11 @@ import sysconfig from pathlib import Path ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') +extension_path = Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}' _TPL_MODELNAME = amici._module_from_path( 'TPL_MODELNAME._TPL_MODELNAME' if __package__ or '.' in __name__ else '_TPL_MODELNAME', - Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}', + extension_path, ) def _get_import_time():