diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a562fd6..c5c1f7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,6 +68,7 @@ repos: hooks: - id: mypy additional_dependencies: + - annotated-types==0.7.0 - joblib-stubs>=1.4.2.3.20240619 args: - "--fast-module-lookup" diff --git a/agogos/_core.py b/agogos/_core.py index 951b34d..d2d8771 100644 --- a/agogos/_core.py +++ b/agogos/_core.py @@ -1,12 +1,17 @@ """This module contains the core classes for all classes in the agogos package.""" -from abc import abstractmethod +from __future__ import annotations + +from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Annotated, Generic, TypeVar +from annotated_types import MinLen from joblib import hash +_DT = TypeVar("_DT") + @dataclass class _Base: @@ -27,11 +32,13 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ + _parent: _Base | None = field(default=None, init=False) + _children: Iterable[_Base] = field(default_factory=list, init=False) + _hash: str = field(default="", init=False) + def __post_init__(self) -> None: """Initialize the block.""" self._set_hash("") - self._set_parent(None) - self._set_children([]) def _set_hash(self, prev_hash: str) -> None: """Set the hash of the block. @@ -47,14 +54,14 @@ def get_hash(self) -> str: """ return self._hash - def get_parent(self) -> Any: # noqa: ANN401 + def get_parent(self) -> _Base | None: """Get the parent of the block. :return: Parent of the block. """ return self._parent - def get_children(self) -> list[Any]: + def get_children(self) -> Iterable[_Base]: """Get the children of the block. :return: Children of the block @@ -70,14 +77,14 @@ def save_to_html(self, file_path: Path) -> None: with open(file_path, "w") as file: file.write(html) - def _set_parent(self, parent: Any) -> None: # noqa: ANN401 + def _set_parent(self, parent: _Base | None) -> None: """Set the parent of the block. :param parent: Parent of the block. """ self._parent = parent - def _set_children(self, children: list[Any]) -> None: + def _set_children(self, children: Iterable[_Base]) -> None: """Set the children of the block. :param children: Children of the block. @@ -128,7 +135,7 @@ def save_to_html(self, file_path: Path) -> None: @dataclass -class _ParallelSystem(_Base): +class _ParallelSystem(Generic[_DT], _Base): """The _System class is the base class for all systems. Parameters: @@ -154,8 +161,8 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - steps: list[_Base] = field(default_factory=list) - weights: list[float] = field(default_factory=list) + steps: Annotated[Sequence[_Base], MinLen(1)] = field(default_factory=list) + weights: Annotated[Sequence[float], MinLen(1)] = field(default_factory=list) def __post_init__(self) -> None: """Post init function of _System class.""" @@ -170,26 +177,25 @@ def __post_init__(self) -> None: # Set weights if they exist if len(self.weights) == len(self.get_steps()): - [w / sum(self.weights) for w in self.weights] + self.weights = [w / sum(self.weights) for w in self.weights] else: num_steps = len(self.get_steps()) - self.weights = [1 / num_steps for x in self.steps] + self.weights = [1 / num_steps] * num_steps self._set_children(self.steps) - def get_steps(self) -> list[_Base]: + def get_steps(self) -> Annotated[Sequence[_Base], MinLen(1)]: """Return list of steps of _ParallelSystem. :return: List of steps. """ - if not self.steps: - return [] return self.steps - def get_weights(self) -> list[float]: + def get_weights(self) -> Annotated[Sequence[float], MinLen(1)]: """Return list of weights of _ParallelSystem. :return: List of weights. + :raises TypeError: If mismatch between weights and steps. """ if len(self.get_steps()) != len(self.weights): raise TypeError("Mismatch between weights and steps") @@ -221,8 +227,7 @@ def _set_hash(self, prev_hash: str) -> None: self._hash = hash(total) - @abstractmethod - def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401 + def concat(self, original_data: _DT | None, data_to_concat: _DT, weight: float = 1.0) -> _DT: """Concatenate the transformed data. :param original_data: The first input data. @@ -255,7 +260,7 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - steps: list[_Base] = field(default_factory=list) + steps: Annotated[Sequence[_Base], MinLen(1)] = field(default_factory=list) def __post_init__(self) -> None: """Post init function of _System class.""" @@ -267,7 +272,7 @@ def __post_init__(self) -> None: self._set_children(self.steps) - def get_steps(self) -> list[_Base]: + def get_steps(self) -> Sequence[_Base]: """Return list of steps of _ParallelSystem. :return: List of steps. diff --git a/agogos/training.py b/agogos/training.py index cb23dbe..1c64e7d 100644 --- a/agogos/training.py +++ b/agogos/training.py @@ -2,39 +2,43 @@ import copy import warnings -from abc import abstractmethod from dataclasses import dataclass -from typing import Any +from typing import Any, Generic, TypeVar, cast from joblib import hash from agogos._core import _Base, _Block, _ParallelSystem, _SequentialSystem from agogos.transforming import TransformingSystem +_XT = TypeVar("_XT") +_YT = TypeVar("_YT") +_PT = TypeVar("_PT") -class TrainType(_Base): + +class TrainType(Generic[_XT, _YT, _PT], _Base): """Abstract train type describing a class that implements two functions train and predict.""" - @abstractmethod - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the block. :param x: The input data. :param y: The target variable. + :param train_args: The arguments to pass to the training system. + :return: The predictions and the target variable. """ raise NotImplementedError(f"{self.__class__.__name__} does not implement train method.") - @abstractmethod - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the target variable. :param x: The input data. + :param pred_args: The arguments to pass to the prediction system. :return: The predictions. """ raise NotImplementedError(f"{self.__class__.__name__} does not implement predict method.") -class Trainer(TrainType, _Block): +class Trainer(TrainType[_XT, _YT, _PT], _Block): """The trainer block is for blocks that need to train on two inputs and predict on one. Methods: @@ -80,7 +84,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any: """ -class TrainingSystem(TrainType, _SequentialSystem): +class TrainingSystem(TrainType[_XT, _YT, _PT], _SequentialSystem): """A system that trains on the input data and labels. Parameters: @@ -118,17 +122,17 @@ def save_to_html(self, file_path: Path) -> None: def __post_init__(self) -> None: """Post init method for the TrainingSystem class.""" - # Assert all steps are a subclass of Trainer + # Assert all steps are a subclass of TrainType for step in self.steps: if not isinstance( step, - (TrainType), + TrainType, ): raise TypeError(f"step: {step} is not an instance of TrainType") super().__post_init__() - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the system. :param x: The input to the system. @@ -153,14 +157,14 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: step_name = step.__class__.__name__ step_args = train_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): x, y = step.train(x, y, **step_args) else: raise TypeError(f"{step} is not an instance of TrainType") return x, y - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the output of the system. :param x: The input to the system. @@ -185,7 +189,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 step_args = pred_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): x = step.predict(x, **step_args) else: raise TypeError(f"{step} is not an instance of TrainType") @@ -193,7 +197,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 return x -class ParallelTrainingSystem(TrainType, _ParallelSystem): +class ParallelTrainingSystem(TrainType[_XT, _YT, _PT], _ParallelSystem[_PT]): """A system that trains the input data in parallel. Parameters: @@ -236,12 +240,12 @@ def __post_init__(self) -> None: """Post init method for the ParallelTrainingSystem class.""" # Assert all steps correct instances for step in self.steps: - if not isinstance(step, (TrainType)): + if not isinstance(step, TrainType): raise TypeError(f"{step} is not an instance of TrainType") super().__post_init__() - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_PT, _YT]: """Train the system. :param x: The input to the system. @@ -255,7 +259,7 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: step_args = train_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): step_x, step_y = step.train(copy.deepcopy(x), copy.deepcopy(y), **step_args) out_x, out_y = ( self.concat(out_x, step_x, self.get_weights()[i]), @@ -264,9 +268,12 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: else: raise TypeError(f"{step} is not an instance of TrainType") + if out_x is None or out_y is None: + raise ValueError("No steps were executed in the training system.") + return out_x, out_y - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _PT: """Predict the output of the system. :param x: The input to the system. @@ -279,15 +286,18 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 step_args = pred_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): step_x = step.predict(copy.deepcopy(x), **step_args) out_x = self.concat(out_x, step_x, self.get_weights()[i]) else: raise TypeError(f"{step} is not an instance of TrainType") + if out_x is None: + raise ValueError("Predictions output is None.") + return out_x - def concat_labels(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401 + def concat_labels(self, original_data: _YT | None, data_to_concat: _YT, weight: float = 1.0) -> _YT: """Concatenate the transformed labels. Will use concat method if not overridden. :param original_data: The first input data. @@ -295,11 +305,11 @@ def concat_labels(self, original_data: Any, data_to_concat: Any, weight: float = :param weight: Weight of data to concat. :return: The concatenated data. """ - return self.concat(original_data, data_to_concat, weight) + return cast(_YT, self.concat(cast(_PT | None, original_data), cast(_PT, data_to_concat), weight)) @dataclass -class Pipeline(TrainType): +class Pipeline(TrainType[_XT, _YT, _PT]): """A pipeline of systems that can be trained and predicted. Parameters: @@ -344,11 +354,11 @@ def save_to_html(self, file_path: Path) -> None: predictions = pipeline.predict(x) """ - x_sys: TransformingSystem | None = None - y_sys: TransformingSystem | None = None - train_sys: Trainer | TrainingSystem | ParallelTrainingSystem | None = None - pred_sys: TransformingSystem | None = None - label_sys: TransformingSystem | None = None + x_sys: TransformingSystem[_XT, _XT] | None = None + y_sys: TransformingSystem[_YT, _YT] | None = None + train_sys: Trainer[_XT, _YT, _PT] | TrainingSystem[_XT, _YT, _PT] | ParallelTrainingSystem[_XT, _YT, _PT] | None = None + pred_sys: TransformingSystem[_XT | _PT, _XT | _PT] | None = None + label_sys: TransformingSystem[_YT, _YT] | None = None def __post_init__(self) -> None: """Post initialization function of the Pipeline.""" @@ -371,7 +381,7 @@ def __post_init__(self) -> None: self._set_children(children) - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the system. :param x: The input to the system. @@ -383,16 +393,18 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: x = self.x_sys.transform(x, **train_args.get("x_sys", {})) if self.y_sys is not None: y = self.y_sys.transform(y, **train_args.get("y_sys", {})) + + predictions: _XT | _PT = x if self.train_sys is not None: - x, y = self.train_sys.train(x, y, **train_args.get("train_sys", {})) + predictions, y = self.train_sys.train(x, y, **train_args.get("train_sys", {})) if self.pred_sys is not None: - x = self.pred_sys.transform(x, **train_args.get("pred_sys", {})) + predictions = self.pred_sys.transform(predictions, **train_args.get("pred_sys", {})) if self.label_sys is not None: y = self.label_sys.transform(y, **train_args.get("label_sys", {})) - return x, y + return predictions, y - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the output of the system. :param x: The input to the system. @@ -401,12 +413,14 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 """ if self.x_sys is not None: x = self.x_sys.transform(x, **pred_args.get("x_sys", {})) + + predictions: _XT | _PT = x if self.train_sys is not None: - x = self.train_sys.predict(x, **pred_args.get("train_sys", {})) + predictions = self.train_sys.predict(x, **pred_args.get("train_sys", {})) if self.pred_sys is not None: - x = self.pred_sys.transform(x, **pred_args.get("pred_sys", {})) + predictions = self.pred_sys.transform(predictions, **pred_args.get("pred_sys", {})) - return x + return predictions def _set_hash(self, prev_hash: str) -> None: """Set the hash of the pipeline. diff --git a/agogos/transforming.py b/agogos/transforming.py index 12c8f2b..ba416c0 100644 --- a/agogos/transforming.py +++ b/agogos/transforming.py @@ -2,17 +2,18 @@ import copy import warnings -from abc import abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar from agogos._core import _Base, _Block, _ParallelSystem, _SequentialSystem +_T = TypeVar("_T") +_R = TypeVar("_R") -class TransformType(_Base): + +class TransformType(Generic[_T, _R], _Base): """Abstract transform type describing a class that implements the transform function.""" - @abstractmethod - def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 + def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. :param data: The input data. @@ -22,7 +23,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 raise NotImplementedError(f"{self.__class__.__name__} does not implement transform method.") -class Transformer(TransformType, _Block): +class Transformer(TransformType[_T, _R], _Block): """The transformer block transforms any data it could be x or y data. Methods: @@ -59,7 +60,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: """ -class TransformingSystem(TransformType, _SequentialSystem): +class TransformingSystem(TransformType[_T, _R], _SequentialSystem): """A system that transforms the input data. Parameters: @@ -99,12 +100,12 @@ def __post_init__(self) -> None: """Post init method for the TransformingSystem class.""" # Assert all steps are a subclass of Transformer for step in self.steps: - if not isinstance(step, (TransformType)): + if not isinstance(step, TransformType): raise TypeError(f"{step} is not an instance of TransformType") super().__post_init__() - def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 + def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. :param data: The input data. @@ -123,7 +124,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 step_name = step.__class__.__name__ step_args = transform_args.get(step_name, {}) - if isinstance(step, (TransformType)): + if isinstance(step, TransformType): data = step.transform(data, **step_args) else: raise TypeError(f"{step} is not an instance of TransformType") @@ -131,7 +132,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 return data -class ParallelTransformingSystem(TransformType, _ParallelSystem): +class ParallelTransformingSystem(TransformType[_T, _R], _ParallelSystem[_R]): """A system that transforms the input data in parallel. Parameters: @@ -179,31 +180,34 @@ def __post_init__(self) -> None: """Post init method for the ParallelTransformingSystem class.""" # Assert all steps are a subclass of Transformer or TransformingSystem for step in self.steps: - if not isinstance(step, (TransformType)): + if not isinstance(step, TransformType): raise TypeError(f"{step} is not an instance of TransformType") super().__post_init__() - def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 + def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. :param data: The input data. + :param transform_args: Additional arguments. :return: The transformed data. """ - # Loop through each step and call the transform method - out_data = None - if len(self.get_steps()) == 0: + if len(self.get_steps()) < 1: return data + out_data = None for i, step in enumerate(self.get_steps()): step_name = step.__class__.__name__ step_args = transform_args.get(step_name, {}) - if isinstance(step, (TransformType)): + if isinstance(step, TransformType): step_data = step.transform(copy.deepcopy(data), **step_args) out_data = self.concat(out_data, step_data, self.get_weights()[i]) else: raise TypeError(f"{step} is not an instance of TransformType") + if out_data is None: + raise ValueError("Transformation output is None.") + return out_data diff --git a/pyproject.toml b/pyproject.toml index 0bad43b..7da5d3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ classifiers = [ "Intended Audience :: Developers" ] dependencies= [ - "joblib>=1.4.0" + "joblib>=1.4.0", + "annotated-types>=0.7.0", ] [project.urls] diff --git a/requirements-dev.lock b/requirements-dev.lock index 2f867d3..7b3c50e 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -12,6 +12,8 @@ -e file:. alabaster==0.7.16 # via sphinx +annotated-types==0.7.0 + # via agogos babel==2.15.0 # via sphinx beautifulsoup4==4.12.3 diff --git a/requirements.lock b/requirements.lock index e2e90ba..bfa1e3e 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,5 +10,7 @@ --index-url https://pypi.org/simple/ -e file:. +annotated-types==0.7.0 + # via agogos joblib==1.4.2 # via agogos diff --git a/tests/test__core.py b/tests/test__core.py index c59470b..50ef687 100644 --- a/tests/test__core.py +++ b/tests/test__core.py @@ -9,7 +9,7 @@ def test_init(self): assert base is not None def test_get_hash(self): - assert _Base().get_hash() == "be5a33685928d3da88062f187295a019" + assert _Base().get_hash() == "e75c7852dd3d36ffc2f1b90efa9568d8" def test_set_hash(self): base = _Base() @@ -39,7 +39,7 @@ def test__repr_html_(self): base = _Base() assert ( base._repr_html_() - == "

Class: _Base

" + == "

Class: _Base

" ) def test_save_to_html(self): @@ -60,17 +60,17 @@ def test_block_set_hash(self): block = _Block() block._set_hash("") hash1 = block.get_hash() - assert hash1 == "04714d9ee40c9baff8c528ed982a103c" + assert hash1 == "8c52898e95f367f12e9079cc62d141cb" block._set_hash(hash1) hash2 = block.get_hash() - assert hash2 == "83196595c42f8eff9218c0ac8f80faf0" + assert hash2 == "243051d47d23a36250057824eed90525" assert hash1 != hash2 def test_block_get_hash(self): block = _Block() block._set_hash("") hash1 = block.get_hash() - assert hash1 == "04714d9ee40c9baff8c528ed982a103c" + assert hash1 == "8c52898e95f367f12e9079cc62d141cb" def test__repr_html_(self): block_instance = _Block() @@ -93,7 +93,7 @@ def test_system_hash_with_1_step(self): block1 = _Block() system = _SequentialSystem([block1]) - assert system.get_hash() == "04714d9ee40c9baff8c528ed982a103c" + assert system.get_hash() == "8c52898e95f367f12e9079cc62d141cb" assert block1.get_hash() == system.get_hash() def test_system_hash_with_2_steps(self): @@ -103,7 +103,7 @@ def test_system_hash_with_2_steps(self): system = _SequentialSystem([block1, block2]) assert system.get_hash() != block1.get_hash() assert ( - system.get_hash() == block2.get_hash() == "83196595c42f8eff9218c0ac8f80faf0" + system.get_hash() == block2.get_hash() == "a60a4a1d474b8454b0ee9197875171f6" ) def test_system_hash_with_3_steps(self): @@ -116,7 +116,7 @@ def test_system_hash_with_3_steps(self): assert system.get_hash() != block2.get_hash() assert block1.get_hash() != block2.get_hash() assert ( - system.get_hash() == block3.get_hash() == "5aaa5f0962baedf36f132ad39380761e" + system.get_hash() == block3.get_hash() == "dfebfe0e709805bd8bc328d115400929" ) def test__repr_html_(self): @@ -140,7 +140,7 @@ def test_parallel_system_hash_with_1_step(self): block1 = _Block() system = _ParallelSystem([block1]) - assert system.get_hash() == "04714d9ee40c9baff8c528ed982a103c" + assert system.get_hash() == "8c52898e95f367f12e9079cc62d141cb" assert block1.get_hash() == system.get_hash() def test_parallel_system_hash_with_2_steps(self): @@ -151,7 +151,7 @@ def test_parallel_system_hash_with_2_steps(self): assert system.get_hash() != block1.get_hash() assert block1.get_hash() == block2.get_hash() assert system.get_hash() != block2.get_hash() - assert system.get_hash() == "9689e0f292013df811f8e910684406f7" + assert system.get_hash() == "6430f15bf7782822914896704610d45d" def test_parallel_system_hash_with_3_steps(self): block1 = _Block() @@ -163,7 +163,7 @@ def test_parallel_system_hash_with_3_steps(self): assert system.get_hash() != block2.get_hash() assert system.get_hash() != block3.get_hash() assert block1.get_hash() == block2.get_hash() == block3.get_hash() - assert system.get_hash() == "b5ea75f99dbfb82c35e082c88b94bda7" + assert system.get_hash() == "e4e48c6863f83ed0e876357da5a6ed24" def test_parallel_system__repr_html_(self): block_instance = _Block() diff --git a/tests/test_training.py b/tests/test_training.py index 7f3e154..abf6b29 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -34,7 +34,7 @@ def predict(self, x): def test_trainer_hash(self): trainer = Trainer() - assert trainer.get_hash() == "0a1fcf1d677d4a1f3f082aa85ffcb684" + assert trainer.get_hash() == "7e416fe4f426f9b52e95a62a93972fa7" class TestTrainingSystem: @@ -569,7 +569,7 @@ def transform(self, x): pred_sys=prediction_system, ) assert x_system.get_hash() != y_system.get_hash() - assert pipeline.get_hash() == "787438b6940d465d122113444249eaa4" + assert pipeline.get_hash() == "dcfb09379dbd5f2b1ba6fd9430e2558a" def test_pipeline_predict_system_hash(self): class TransformingBlock(Transformer): @@ -588,7 +588,7 @@ def transform(self, x): pred_sys=prediction_system, ) assert prediction_system.get_hash() != x_system.get_hash() - assert pipeline.get_hash() == "842e1162d744e7ab09c941300a43c218" + assert pipeline.get_hash() == "df836b9fd09c38916f256b43c368caa7" def test_pipeline_pre_post_hash(self): class TransformingBlock(Transformer): @@ -607,4 +607,4 @@ def transform(self, x): pred_sys=prediction_system, ) assert x_system.get_hash() != prediction_system.get_hash() - assert pipeline.get_hash() == "8a0be6742040f6a05d805ac79b486f6c" + assert pipeline.get_hash() == "7ea1ca20dd1478a47bc44c2427a09520" diff --git a/tests/test_transforming.py b/tests/test_transforming.py index 5cf59cc..0bc96d5 100644 --- a/tests/test_transforming.py +++ b/tests/test_transforming.py @@ -28,7 +28,7 @@ def transform(self, data): def test_transformer_hash(self): transformer = Transformer() - assert transformer.get_hash() == "1cbcc4f2d0921b050d9b719d2beb6529" + assert transformer.get_hash() == "6afa3259ad9293ce2c3ff6ca58ab8d68" class TestTransformingSystem: @@ -300,12 +300,12 @@ def transform(self, data): system.transform([1, 2, 3]) def test_transform_parallel_hashes(self): - class SubTransformer1(Transformer): - def transform(self, x): + class SubTransformer1(Transformer[int, int]): + def transform(self, x: int, **kwargs) -> int: return x - class SubTransformer2(Transformer): - def transform(self, x): + class SubTransformer2(Transformer[int, int]): + def transform(self, x: int, **kwargs) -> int: return x * 2 block1 = SubTransformer1()