diff --git a/epochlib/pipeline/training.py b/epochlib/pipeline/training.py index dabada3..3938f45 100644 --- a/epochlib/pipeline/training.py +++ b/epochlib/pipeline/training.py @@ -400,7 +400,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any: return x - def _set_hash(self, prev_hash: str) -> None: + def set_hash(self, prev_hash: str) -> None: """Set the hash of the pipeline. :param prev_hash: The hash of the previous block. diff --git a/pyproject.toml b/pyproject.toml index 11e74bb..f07446b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "epochlib" -version = "5.0.0" +version = "5.0.1" authors = [ { name = "Jasper van Selm", email = "jmvanselm@gmail.com" }, { name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" }, diff --git a/tests/pipeline/test_training.py b/tests/pipeline/test_training.py index 0b7e886..0eab661 100644 --- a/tests/pipeline/test_training.py +++ b/tests/pipeline/test_training.py @@ -551,10 +551,6 @@ def test_pipeline_get_hash_no_change(self): pred_sys=predicting_system, ) assert x_system.get_hash() == "" - # assert y_system.get_hash() == "" - # assert training_system.get_hash() == "" - # assert predicting_system.get_hash() == "" - # assert pipeline.get_hash() == "" def test_pipeline_get_hash_with_change(self): class TransformingBlock(Transformer): @@ -604,11 +600,18 @@ def transform(self, x): y_system = TransformingSystem() training_system = TrainingSystem() prediction_system = TransformingSystem(steps=[transform1]) - pipeline = Pipeline( + assert x_system.get_hash() == prediction_system.get_hash() + pipeline1 = Pipeline( x_sys=x_system, y_sys=y_system, train_sys=training_system, pred_sys=prediction_system, ) - assert x_system.get_hash() == prediction_system.get_hash() - assert pipeline.get_hash() != "" + pipeline1_train_sys_hash = pipeline1.train_sys.get_hash() + pipeline2 = Pipeline( + x_sys=TransformingSystem(), + y_sys=y_system, + train_sys=training_system, + pred_sys=prediction_system, + ) + assert pipeline1_train_sys_hash != pipeline2.train_sys.get_hash()