From c6b3446a204d1a17802091d405a65d7a25d15a2e Mon Sep 17 00:00:00 2001 From: Jeffrey Lim Date: Wed, 26 Jun 2024 15:19:50 +0200 Subject: [PATCH 1/4] Refactor _core.py --- agogos/_core.py | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/agogos/_core.py b/agogos/_core.py index 951b34d..40eb5b9 100644 --- a/agogos/_core.py +++ b/agogos/_core.py @@ -1,15 +1,20 @@ """This module contains the core classes for all classes in the agogos package.""" - -from abc import abstractmethod +import numbers +from abc import abstractmethod, ABC +from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, TypeVar, Generic from joblib import hash +_ParentT = TypeVar("_ParentT", bound="_Base") +_ChildT = TypeVar("_ChildT", bound="_Base") +_DT = TypeVar("_DT") -@dataclass -class _Base: + +@dataclass(slots=True) +class _Base(Generic[_ParentT, _ChildT]): """The _Base class is the base class for all classes in the agogos package. Methods: @@ -26,12 +31,13 @@ def get_children(self) -> list[Any]: def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ + _parent: _ParentT | None = None + _children: Iterable[_ChildT] = field(default_factory=list) + _hash: str = field(init=False) def __post_init__(self) -> None: """Initialize the block.""" self._set_hash("") - self._set_parent(None) - self._set_children([]) def _set_hash(self, prev_hash: str) -> None: """Set the hash of the block. @@ -47,14 +53,14 @@ def get_hash(self) -> str: """ return self._hash - def get_parent(self) -> Any: # noqa: ANN401 + def get_parent(self) -> _ParentT: # noqa: ANN401 """Get the parent of the block. :return: Parent of the block. """ return self._parent - def get_children(self) -> list[Any]: + def get_children(self) -> Iterable[_ChildT]: """Get the children of the block. :return: Children of the block @@ -70,14 +76,14 @@ def save_to_html(self, file_path: Path) -> None: with open(file_path, "w") as file: file.write(html) - def _set_parent(self, parent: Any) -> None: # noqa: ANN401 + def _set_parent(self, parent: _ParentT | None) -> None: # noqa: ANN401 """Set the parent of the block. :param parent: Parent of the block. """ self._parent = parent - def _set_children(self, children: list[Any]) -> None: + def _set_children(self, children: Iterable[_ChildT]) -> None: """Set the children of the block. :param children: Children of the block. @@ -127,8 +133,8 @@ def save_to_html(self, file_path: Path) -> None: """ -@dataclass -class _ParallelSystem(_Base): +@dataclass(slots=True) +class _ParallelSystem(ABC, Generic[_DT], _Base): """The _System class is the base class for all systems. Parameters: @@ -154,8 +160,8 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - steps: list[_Base] = field(default_factory=list) - weights: list[float] = field(default_factory=list) + steps: Sequence[_ChildT] = field(default_factory=list) + weights: Sequence[numbers.Real] = field(default_factory=list) def __post_init__(self) -> None: """Post init function of _System class.""" @@ -177,16 +183,14 @@ def __post_init__(self) -> None: self._set_children(self.steps) - def get_steps(self) -> list[_Base]: + def get_steps(self) -> Sequence[_ChildT]: """Return list of steps of _ParallelSystem. :return: List of steps. """ - if not self.steps: - return [] return self.steps - def get_weights(self) -> list[float]: + def get_weights(self) -> Sequence[numbers.Real]: """Return list of weights of _ParallelSystem. :return: List of weights. @@ -222,7 +226,7 @@ def _set_hash(self, prev_hash: str) -> None: self._hash = hash(total) @abstractmethod - def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401 + def concat(self, original_data: _DT, data_to_concat: _DT, weight: numbers.Real = 1.0) -> _DT: # noqa: ANN401 """Concatenate the transformed data. :param original_data: The first input data. @@ -255,7 +259,7 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - steps: list[_Base] = field(default_factory=list) + steps: Sequence[_ChildT] = field(default_factory=list) def __post_init__(self) -> None: """Post init function of _System class.""" @@ -267,7 +271,7 @@ def __post_init__(self) -> None: self._set_children(self.steps) - def get_steps(self) -> list[_Base]: + def get_steps(self) -> Sequence[_ChildT]: """Return list of steps of _ParallelSystem. :return: List of steps. From 983054f9bec4b84af2aa0689aeb0190dc0da8212 Mon Sep 17 00:00:00 2001 From: Jeffrey Lim Date: Fri, 28 Jun 2024 11:11:57 +0200 Subject: [PATCH 2/4] Use generics instead of Any --- .pre-commit-config.yaml | 1 + agogos/_core.py | 49 ++++++++++++----------- agogos/training.py | 88 +++++++++++++++++++++++++---------------- agogos/transforming.py | 36 ++++++++++------- pyproject.toml | 3 +- requirements-dev.lock | 2 + requirements.lock | 2 + 7 files changed, 108 insertions(+), 73 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a562fd6..c5c1f7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,6 +68,7 @@ repos: hooks: - id: mypy additional_dependencies: + - annotated-types==0.7.0 - joblib-stubs>=1.4.2.3.20240619 args: - "--fast-module-lookup" diff --git a/agogos/_core.py b/agogos/_core.py index 40eb5b9..a355345 100644 --- a/agogos/_core.py +++ b/agogos/_core.py @@ -1,20 +1,21 @@ """This module contains the core classes for all classes in the agogos package.""" -import numbers -from abc import abstractmethod, ABC + +from __future__ import annotations + +from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from pathlib import Path -from typing import Any, TypeVar, Generic +from typing import Annotated, Generic, TypeVar +from annotated_types import MinLen from joblib import hash -_ParentT = TypeVar("_ParentT", bound="_Base") -_ChildT = TypeVar("_ChildT", bound="_Base") _DT = TypeVar("_DT") -@dataclass(slots=True) -class _Base(Generic[_ParentT, _ChildT]): +@dataclass +class _Base: """The _Base class is the base class for all classes in the agogos package. Methods: @@ -31,8 +32,9 @@ def get_children(self) -> list[Any]: def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - _parent: _ParentT | None = None - _children: Iterable[_ChildT] = field(default_factory=list) + + _parent: _Base | None = None + _children: Iterable[_Base] = field(default_factory=list) _hash: str = field(init=False) def __post_init__(self) -> None: @@ -53,14 +55,14 @@ def get_hash(self) -> str: """ return self._hash - def get_parent(self) -> _ParentT: # noqa: ANN401 + def get_parent(self) -> _Base | None: """Get the parent of the block. :return: Parent of the block. """ return self._parent - def get_children(self) -> Iterable[_ChildT]: + def get_children(self) -> Iterable[_Base]: """Get the children of the block. :return: Children of the block @@ -76,14 +78,14 @@ def save_to_html(self, file_path: Path) -> None: with open(file_path, "w") as file: file.write(html) - def _set_parent(self, parent: _ParentT | None) -> None: # noqa: ANN401 + def _set_parent(self, parent: _Base | None) -> None: """Set the parent of the block. :param parent: Parent of the block. """ self._parent = parent - def _set_children(self, children: Iterable[_ChildT]) -> None: + def _set_children(self, children: Iterable[_Base]) -> None: """Set the children of the block. :param children: Children of the block. @@ -133,7 +135,7 @@ def save_to_html(self, file_path: Path) -> None: """ -@dataclass(slots=True) +@dataclass class _ParallelSystem(ABC, Generic[_DT], _Base): """The _System class is the base class for all systems. @@ -160,8 +162,8 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - steps: Sequence[_ChildT] = field(default_factory=list) - weights: Sequence[numbers.Real] = field(default_factory=list) + steps: Annotated[Sequence[_Base], MinLen(1)] = field(default_factory=list) + weights: Annotated[Sequence[float], MinLen(1)] = field(default_factory=list) def __post_init__(self) -> None: """Post init function of _System class.""" @@ -176,24 +178,25 @@ def __post_init__(self) -> None: # Set weights if they exist if len(self.weights) == len(self.get_steps()): - [w / sum(self.weights) for w in self.weights] + self.weights = [w / sum(self.weights) for w in self.weights] else: num_steps = len(self.get_steps()) - self.weights = [1 / num_steps for x in self.steps] + self.weights = [1 / num_steps] * num_steps self._set_children(self.steps) - def get_steps(self) -> Sequence[_ChildT]: + def get_steps(self) -> Annotated[Sequence[_Base], MinLen(1)]: """Return list of steps of _ParallelSystem. :return: List of steps. """ return self.steps - def get_weights(self) -> Sequence[numbers.Real]: + def get_weights(self) -> Annotated[Sequence[float], MinLen(1)]: """Return list of weights of _ParallelSystem. :return: List of weights. + :raises TypeError: If mismatch between weights and steps. """ if len(self.get_steps()) != len(self.weights): raise TypeError("Mismatch between weights and steps") @@ -226,7 +229,7 @@ def _set_hash(self, prev_hash: str) -> None: self._hash = hash(total) @abstractmethod - def concat(self, original_data: _DT, data_to_concat: _DT, weight: numbers.Real = 1.0) -> _DT: # noqa: ANN401 + def concat(self, original_data: _DT | None, data_to_concat: _DT, weight: float = 1.0) -> _DT: """Concatenate the transformed data. :param original_data: The first input data. @@ -259,7 +262,7 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - steps: Sequence[_ChildT] = field(default_factory=list) + steps: Annotated[Sequence[_Base], MinLen(1)] = field(default_factory=list) def __post_init__(self) -> None: """Post init function of _System class.""" @@ -271,7 +274,7 @@ def __post_init__(self) -> None: self._set_children(self.steps) - def get_steps(self) -> Sequence[_ChildT]: + def get_steps(self) -> Sequence[_Base]: """Return list of steps of _ParallelSystem. :return: List of steps. diff --git a/agogos/training.py b/agogos/training.py index cb23dbe..1e747cd 100644 --- a/agogos/training.py +++ b/agogos/training.py @@ -2,39 +2,48 @@ import copy import warnings -from abc import abstractmethod +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any +from typing import Any, Generic, TypeVar, cast from joblib import hash from agogos._core import _Base, _Block, _ParallelSystem, _SequentialSystem from agogos.transforming import TransformingSystem +_XT = TypeVar("_XT") +_YT = TypeVar("_YT") +_PT = TypeVar("_PT") -class TrainType(_Base): + +@dataclass +class TrainType(ABC, Generic[_XT, _YT, _PT], _Base): """Abstract train type describing a class that implements two functions train and predict.""" @abstractmethod - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the block. :param x: The input data. :param y: The target variable. + :param train_args: The arguments to pass to the training system. + :return: The predictions and the target variable. """ raise NotImplementedError(f"{self.__class__.__name__} does not implement train method.") @abstractmethod - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the target variable. :param x: The input data. + :param pred_args: The arguments to pass to the prediction system. :return: The predictions. """ raise NotImplementedError(f"{self.__class__.__name__} does not implement predict method.") -class Trainer(TrainType, _Block): +@dataclass +class Trainer(TrainType[_XT, _YT, _PT], _Block): """The trainer block is for blocks that need to train on two inputs and predict on one. Methods: @@ -80,7 +89,8 @@ def predict(self, x: Any, **pred_args: Any) -> Any: """ -class TrainingSystem(TrainType, _SequentialSystem): +@dataclass +class TrainingSystem(TrainType[_XT, _YT, _PT], _SequentialSystem): """A system that trains on the input data and labels. Parameters: @@ -118,17 +128,17 @@ def save_to_html(self, file_path: Path) -> None: def __post_init__(self) -> None: """Post init method for the TrainingSystem class.""" - # Assert all steps are a subclass of Trainer + # Assert all steps are a subclass of TrainType for step in self.steps: if not isinstance( step, - (TrainType), + TrainType, ): raise TypeError(f"step: {step} is not an instance of TrainType") super().__post_init__() - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the system. :param x: The input to the system. @@ -153,14 +163,14 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: step_name = step.__class__.__name__ step_args = train_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): x, y = step.train(x, y, **step_args) else: raise TypeError(f"{step} is not an instance of TrainType") return x, y - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the output of the system. :param x: The input to the system. @@ -185,7 +195,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 step_args = pred_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): x = step.predict(x, **step_args) else: raise TypeError(f"{step} is not an instance of TrainType") @@ -193,7 +203,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 return x -class ParallelTrainingSystem(TrainType, _ParallelSystem): +class ParallelTrainingSystem(TrainType[_XT, _YT, _PT], _ParallelSystem[_PT]): """A system that trains the input data in parallel. Parameters: @@ -236,12 +246,12 @@ def __post_init__(self) -> None: """Post init method for the ParallelTrainingSystem class.""" # Assert all steps correct instances for step in self.steps: - if not isinstance(step, (TrainType)): + if not isinstance(step, TrainType): raise TypeError(f"{step} is not an instance of TrainType") super().__post_init__() - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_PT, _YT]: """Train the system. :param x: The input to the system. @@ -264,9 +274,12 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: else: raise TypeError(f"{step} is not an instance of TrainType") + if out_x is None or out_y is None: + raise ValueError("No steps were executed in the training system.") + return out_x, out_y - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _PT: """Predict the output of the system. :param x: The input to the system. @@ -279,15 +292,18 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 step_args = pred_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): step_x = step.predict(copy.deepcopy(x), **step_args) out_x = self.concat(out_x, step_x, self.get_weights()[i]) else: raise TypeError(f"{step} is not an instance of TrainType") + if out_x is None: + raise ValueError("Predictions output is None.") + return out_x - def concat_labels(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401 + def concat_labels(self, original_data: _YT | None, data_to_concat: _YT, weight: float = 1.0) -> _YT: """Concatenate the transformed labels. Will use concat method if not overridden. :param original_data: The first input data. @@ -295,11 +311,11 @@ def concat_labels(self, original_data: Any, data_to_concat: Any, weight: float = :param weight: Weight of data to concat. :return: The concatenated data. """ - return self.concat(original_data, data_to_concat, weight) + return cast(_YT, self.concat(cast(_PT | None, original_data), cast(_PT, data_to_concat), weight)) @dataclass -class Pipeline(TrainType): +class Pipeline(TrainType[_XT, _YT, _PT]): """A pipeline of systems that can be trained and predicted. Parameters: @@ -344,11 +360,11 @@ def save_to_html(self, file_path: Path) -> None: predictions = pipeline.predict(x) """ - x_sys: TransformingSystem | None = None - y_sys: TransformingSystem | None = None - train_sys: Trainer | TrainingSystem | ParallelTrainingSystem | None = None - pred_sys: TransformingSystem | None = None - label_sys: TransformingSystem | None = None + x_sys: TransformingSystem[_XT, _XT] | None = None + y_sys: TransformingSystem[_YT, _YT] | None = None + train_sys: Trainer[_XT, _YT, _PT] | TrainingSystem[_XT, _YT, _PT] | ParallelTrainingSystem[_XT, _YT, _PT] | None = None + pred_sys: TransformingSystem[_XT | _PT, _XT | _PT] | None = None + label_sys: TransformingSystem[_YT, _YT] | None = None def __post_init__(self) -> None: """Post initialization function of the Pipeline.""" @@ -371,7 +387,7 @@ def __post_init__(self) -> None: self._set_children(children) - def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401 + def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the system. :param x: The input to the system. @@ -383,16 +399,18 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: x = self.x_sys.transform(x, **train_args.get("x_sys", {})) if self.y_sys is not None: y = self.y_sys.transform(y, **train_args.get("y_sys", {})) + + predictions: _XT | _PT = x if self.train_sys is not None: - x, y = self.train_sys.train(x, y, **train_args.get("train_sys", {})) + predictions, y = self.train_sys.train(x, y, **train_args.get("train_sys", {})) if self.pred_sys is not None: - x = self.pred_sys.transform(x, **train_args.get("pred_sys", {})) + predictions = self.pred_sys.transform(predictions, **train_args.get("pred_sys", {})) if self.label_sys is not None: y = self.label_sys.transform(y, **train_args.get("label_sys", {})) - return x, y + return predictions, y - def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 + def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the output of the system. :param x: The input to the system. @@ -401,12 +419,14 @@ def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401 """ if self.x_sys is not None: x = self.x_sys.transform(x, **pred_args.get("x_sys", {})) + + predictions: _XT | _PT = x if self.train_sys is not None: - x = self.train_sys.predict(x, **pred_args.get("train_sys", {})) + predictions = self.train_sys.predict(x, **pred_args.get("train_sys", {})) if self.pred_sys is not None: - x = self.pred_sys.transform(x, **pred_args.get("pred_sys", {})) + predictions = self.pred_sys.transform(predictions, **pred_args.get("pred_sys", {})) - return x + return predictions def _set_hash(self, prev_hash: str) -> None: """Set the hash of the pipeline. diff --git a/agogos/transforming.py b/agogos/transforming.py index 12c8f2b..6c25238 100644 --- a/agogos/transforming.py +++ b/agogos/transforming.py @@ -3,16 +3,19 @@ import copy import warnings from abc import abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar from agogos._core import _Base, _Block, _ParallelSystem, _SequentialSystem +_T = TypeVar("_T") +_R = TypeVar("_R") -class TransformType(_Base): + +class TransformType(Generic[_T, _R], _Base): """Abstract transform type describing a class that implements the transform function.""" @abstractmethod - def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 + def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. :param data: The input data. @@ -22,7 +25,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 raise NotImplementedError(f"{self.__class__.__name__} does not implement transform method.") -class Transformer(TransformType, _Block): +class Transformer(TransformType[_T, _R], _Block): """The transformer block transforms any data it could be x or y data. Methods: @@ -59,7 +62,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: """ -class TransformingSystem(TransformType, _SequentialSystem): +class TransformingSystem(TransformType[_T, _R], _SequentialSystem): """A system that transforms the input data. Parameters: @@ -99,12 +102,12 @@ def __post_init__(self) -> None: """Post init method for the TransformingSystem class.""" # Assert all steps are a subclass of Transformer for step in self.steps: - if not isinstance(step, (TransformType)): + if not isinstance(step, TransformType): raise TypeError(f"{step} is not an instance of TransformType") super().__post_init__() - def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 + def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. :param data: The input data. @@ -123,7 +126,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 step_name = step.__class__.__name__ step_args = transform_args.get(step_name, {}) - if isinstance(step, (TransformType)): + if isinstance(step, TransformType): data = step.transform(data, **step_args) else: raise TypeError(f"{step} is not an instance of TransformType") @@ -131,7 +134,7 @@ def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 return data -class ParallelTransformingSystem(TransformType, _ParallelSystem): +class ParallelTransformingSystem(TransformType[_T, _R], _ParallelSystem[_R]): """A system that transforms the input data in parallel. Parameters: @@ -179,31 +182,34 @@ def __post_init__(self) -> None: """Post init method for the ParallelTransformingSystem class.""" # Assert all steps are a subclass of Transformer or TransformingSystem for step in self.steps: - if not isinstance(step, (TransformType)): + if not isinstance(step, TransformType): raise TypeError(f"{step} is not an instance of TransformType") super().__post_init__() - def transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401 + def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. :param data: The input data. + :param transform_args: Additional arguments. :return: The transformed data. """ - # Loop through each step and call the transform method - out_data = None - if len(self.get_steps()) == 0: + if len(self.get_steps()) < 1: return data + out_data = None for i, step in enumerate(self.get_steps()): step_name = step.__class__.__name__ step_args = transform_args.get(step_name, {}) - if isinstance(step, (TransformType)): + if isinstance(step, TransformType): step_data = step.transform(copy.deepcopy(data), **step_args) out_data = self.concat(out_data, step_data, self.get_weights()[i]) else: raise TypeError(f"{step} is not an instance of TransformType") + if out_data is None: + raise ValueError("Transformation output is None.") + return out_data diff --git a/pyproject.toml b/pyproject.toml index 0bad43b..7da5d3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ classifiers = [ "Intended Audience :: Developers" ] dependencies= [ - "joblib>=1.4.0" + "joblib>=1.4.0", + "annotated-types>=0.7.0", ] [project.urls] diff --git a/requirements-dev.lock b/requirements-dev.lock index 2f867d3..7b3c50e 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -12,6 +12,8 @@ -e file:. alabaster==0.7.16 # via sphinx +annotated-types==0.7.0 + # via agogos babel==2.15.0 # via sphinx beautifulsoup4==4.12.3 diff --git a/requirements.lock b/requirements.lock index e2e90ba..bfa1e3e 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,5 +10,7 @@ --index-url https://pypi.org/simple/ -e file:. +annotated-types==0.7.0 + # via agogos joblib==1.4.2 # via agogos From c9e10b14e37e271e88f2d81e0689f3d3f2fceee1 Mon Sep 17 00:00:00 2001 From: Jeffrey Lim Date: Fri, 28 Jun 2024 12:01:57 +0200 Subject: [PATCH 3/4] Fix most failing tests --- agogos/_core.py | 9 ++++----- agogos/training.py | 9 ++------- agogos/transforming.py | 1 - tests/test__core.py | 22 +++++++++++----------- tests/test_training.py | 8 ++++---- tests/test_transforming.py | 10 +++++----- 6 files changed, 26 insertions(+), 33 deletions(-) diff --git a/agogos/_core.py b/agogos/_core.py index a355345..ad0146a 100644 --- a/agogos/_core.py +++ b/agogos/_core.py @@ -33,9 +33,9 @@ def save_to_html(self, file_path: Path) -> None: # Save html format to file_path """ - _parent: _Base | None = None - _children: Iterable[_Base] = field(default_factory=list) - _hash: str = field(init=False) + _parent: _Base | None = field(default=None, init=False) + _children: Iterable[_Base] = field(default_factory=list, init=False) + _hash: str = field(default="", init=False) def __post_init__(self) -> None: """Initialize the block.""" @@ -136,7 +136,7 @@ def save_to_html(self, file_path: Path) -> None: @dataclass -class _ParallelSystem(ABC, Generic[_DT], _Base): +class _ParallelSystem(Generic[_DT], _Base): """The _System class is the base class for all systems. Parameters: @@ -228,7 +228,6 @@ def _set_hash(self, prev_hash: str) -> None: self._hash = hash(total) - @abstractmethod def concat(self, original_data: _DT | None, data_to_concat: _DT, weight: float = 1.0) -> _DT: """Concatenate the transformed data. diff --git a/agogos/training.py b/agogos/training.py index 1e747cd..944c0c7 100644 --- a/agogos/training.py +++ b/agogos/training.py @@ -16,11 +16,9 @@ _PT = TypeVar("_PT") -@dataclass -class TrainType(ABC, Generic[_XT, _YT, _PT], _Base): +class TrainType(Generic[_XT, _YT, _PT], _Base): """Abstract train type describing a class that implements two functions train and predict.""" - @abstractmethod def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """Train the block. @@ -31,7 +29,6 @@ def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_XT | _PT, _YT]: """ raise NotImplementedError(f"{self.__class__.__name__} does not implement train method.") - @abstractmethod def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: """Predict the target variable. @@ -42,7 +39,6 @@ def predict(self, x: _XT, **pred_args: Any) -> _XT | _PT: raise NotImplementedError(f"{self.__class__.__name__} does not implement predict method.") -@dataclass class Trainer(TrainType[_XT, _YT, _PT], _Block): """The trainer block is for blocks that need to train on two inputs and predict on one. @@ -89,7 +85,6 @@ def predict(self, x: Any, **pred_args: Any) -> Any: """ -@dataclass class TrainingSystem(TrainType[_XT, _YT, _PT], _SequentialSystem): """A system that trains on the input data and labels. @@ -265,7 +260,7 @@ def train(self, x: _XT, y: _YT, **train_args: Any) -> tuple[_PT, _YT]: step_args = train_args.get(step_name, {}) - if isinstance(step, (TrainType)): + if isinstance(step, TrainType): step_x, step_y = step.train(copy.deepcopy(x), copy.deepcopy(y), **step_args) out_x, out_y = ( self.concat(out_x, step_x, self.get_weights()[i]), diff --git a/agogos/transforming.py b/agogos/transforming.py index 6c25238..63de149 100644 --- a/agogos/transforming.py +++ b/agogos/transforming.py @@ -14,7 +14,6 @@ class TransformType(Generic[_T, _R], _Base): """Abstract transform type describing a class that implements the transform function.""" - @abstractmethod def transform(self, data: _T, **transform_args: Any) -> _T | _R: """Transform the input data. diff --git a/tests/test__core.py b/tests/test__core.py index c59470b..50ef687 100644 --- a/tests/test__core.py +++ b/tests/test__core.py @@ -9,7 +9,7 @@ def test_init(self): assert base is not None def test_get_hash(self): - assert _Base().get_hash() == "be5a33685928d3da88062f187295a019" + assert _Base().get_hash() == "e75c7852dd3d36ffc2f1b90efa9568d8" def test_set_hash(self): base = _Base() @@ -39,7 +39,7 @@ def test__repr_html_(self): base = _Base() assert ( base._repr_html_() - == "

Class: _Base

  • Hash: be5a33685928d3da88062f187295a019
  • Parent: None
  • Children: None
" + == "

Class: _Base

  • Hash: e75c7852dd3d36ffc2f1b90efa9568d8
  • Parent: None
  • Children: None
" ) def test_save_to_html(self): @@ -60,17 +60,17 @@ def test_block_set_hash(self): block = _Block() block._set_hash("") hash1 = block.get_hash() - assert hash1 == "04714d9ee40c9baff8c528ed982a103c" + assert hash1 == "8c52898e95f367f12e9079cc62d141cb" block._set_hash(hash1) hash2 = block.get_hash() - assert hash2 == "83196595c42f8eff9218c0ac8f80faf0" + assert hash2 == "243051d47d23a36250057824eed90525" assert hash1 != hash2 def test_block_get_hash(self): block = _Block() block._set_hash("") hash1 = block.get_hash() - assert hash1 == "04714d9ee40c9baff8c528ed982a103c" + assert hash1 == "8c52898e95f367f12e9079cc62d141cb" def test__repr_html_(self): block_instance = _Block() @@ -93,7 +93,7 @@ def test_system_hash_with_1_step(self): block1 = _Block() system = _SequentialSystem([block1]) - assert system.get_hash() == "04714d9ee40c9baff8c528ed982a103c" + assert system.get_hash() == "8c52898e95f367f12e9079cc62d141cb" assert block1.get_hash() == system.get_hash() def test_system_hash_with_2_steps(self): @@ -103,7 +103,7 @@ def test_system_hash_with_2_steps(self): system = _SequentialSystem([block1, block2]) assert system.get_hash() != block1.get_hash() assert ( - system.get_hash() == block2.get_hash() == "83196595c42f8eff9218c0ac8f80faf0" + system.get_hash() == block2.get_hash() == "a60a4a1d474b8454b0ee9197875171f6" ) def test_system_hash_with_3_steps(self): @@ -116,7 +116,7 @@ def test_system_hash_with_3_steps(self): assert system.get_hash() != block2.get_hash() assert block1.get_hash() != block2.get_hash() assert ( - system.get_hash() == block3.get_hash() == "5aaa5f0962baedf36f132ad39380761e" + system.get_hash() == block3.get_hash() == "dfebfe0e709805bd8bc328d115400929" ) def test__repr_html_(self): @@ -140,7 +140,7 @@ def test_parallel_system_hash_with_1_step(self): block1 = _Block() system = _ParallelSystem([block1]) - assert system.get_hash() == "04714d9ee40c9baff8c528ed982a103c" + assert system.get_hash() == "8c52898e95f367f12e9079cc62d141cb" assert block1.get_hash() == system.get_hash() def test_parallel_system_hash_with_2_steps(self): @@ -151,7 +151,7 @@ def test_parallel_system_hash_with_2_steps(self): assert system.get_hash() != block1.get_hash() assert block1.get_hash() == block2.get_hash() assert system.get_hash() != block2.get_hash() - assert system.get_hash() == "9689e0f292013df811f8e910684406f7" + assert system.get_hash() == "6430f15bf7782822914896704610d45d" def test_parallel_system_hash_with_3_steps(self): block1 = _Block() @@ -163,7 +163,7 @@ def test_parallel_system_hash_with_3_steps(self): assert system.get_hash() != block2.get_hash() assert system.get_hash() != block3.get_hash() assert block1.get_hash() == block2.get_hash() == block3.get_hash() - assert system.get_hash() == "b5ea75f99dbfb82c35e082c88b94bda7" + assert system.get_hash() == "e4e48c6863f83ed0e876357da5a6ed24" def test_parallel_system__repr_html_(self): block_instance = _Block() diff --git a/tests/test_training.py b/tests/test_training.py index 7f3e154..abf6b29 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -34,7 +34,7 @@ def predict(self, x): def test_trainer_hash(self): trainer = Trainer() - assert trainer.get_hash() == "0a1fcf1d677d4a1f3f082aa85ffcb684" + assert trainer.get_hash() == "7e416fe4f426f9b52e95a62a93972fa7" class TestTrainingSystem: @@ -569,7 +569,7 @@ def transform(self, x): pred_sys=prediction_system, ) assert x_system.get_hash() != y_system.get_hash() - assert pipeline.get_hash() == "787438b6940d465d122113444249eaa4" + assert pipeline.get_hash() == "dcfb09379dbd5f2b1ba6fd9430e2558a" def test_pipeline_predict_system_hash(self): class TransformingBlock(Transformer): @@ -588,7 +588,7 @@ def transform(self, x): pred_sys=prediction_system, ) assert prediction_system.get_hash() != x_system.get_hash() - assert pipeline.get_hash() == "842e1162d744e7ab09c941300a43c218" + assert pipeline.get_hash() == "df836b9fd09c38916f256b43c368caa7" def test_pipeline_pre_post_hash(self): class TransformingBlock(Transformer): @@ -607,4 +607,4 @@ def transform(self, x): pred_sys=prediction_system, ) assert x_system.get_hash() != prediction_system.get_hash() - assert pipeline.get_hash() == "8a0be6742040f6a05d805ac79b486f6c" + assert pipeline.get_hash() == "7ea1ca20dd1478a47bc44c2427a09520" diff --git a/tests/test_transforming.py b/tests/test_transforming.py index 5cf59cc..0bc96d5 100644 --- a/tests/test_transforming.py +++ b/tests/test_transforming.py @@ -28,7 +28,7 @@ def transform(self, data): def test_transformer_hash(self): transformer = Transformer() - assert transformer.get_hash() == "1cbcc4f2d0921b050d9b719d2beb6529" + assert transformer.get_hash() == "6afa3259ad9293ce2c3ff6ca58ab8d68" class TestTransformingSystem: @@ -300,12 +300,12 @@ def transform(self, data): system.transform([1, 2, 3]) def test_transform_parallel_hashes(self): - class SubTransformer1(Transformer): - def transform(self, x): + class SubTransformer1(Transformer[int, int]): + def transform(self, x: int, **kwargs) -> int: return x - class SubTransformer2(Transformer): - def transform(self, x): + class SubTransformer2(Transformer[int, int]): + def transform(self, x: int, **kwargs) -> int: return x * 2 block1 = SubTransformer1() From 6d6b74b4dad19e9d5944b2892bd158cef66994d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Jun 2024 10:08:22 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- agogos/_core.py | 1 - agogos/training.py | 1 - agogos/transforming.py | 1 - 3 files changed, 3 deletions(-) diff --git a/agogos/_core.py b/agogos/_core.py index ad0146a..d2d8771 100644 --- a/agogos/_core.py +++ b/agogos/_core.py @@ -2,7 +2,6 @@ from __future__ import annotations -from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from pathlib import Path diff --git a/agogos/training.py b/agogos/training.py index 944c0c7..1c64e7d 100644 --- a/agogos/training.py +++ b/agogos/training.py @@ -2,7 +2,6 @@ import copy import warnings -from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Generic, TypeVar, cast diff --git a/agogos/transforming.py b/agogos/transforming.py index 63de149..ba416c0 100644 --- a/agogos/transforming.py +++ b/agogos/transforming.py @@ -2,7 +2,6 @@ import copy import warnings -from abc import abstractmethod from typing import Any, Generic, TypeVar from agogos._core import _Base, _Block, _ParallelSystem, _SequentialSystem