diff --git a/.github/workflows/main-branch-testing.yml b/.github/workflows/main-branch-testing.yml
index ffb66c3..69a504b 100644
--- a/.github/workflows/main-branch-testing.yml
+++ b/.github/workflows/main-branch-testing.yml
@@ -48,4 +48,4 @@ jobs:
- name: Test with pytest
run: |
- venv/bin/python -m pytest --cov=epochlib --cov-branch --cov-fail-under=95 tests
+ venv/bin/python -m pytest --cov=epochlib --cov-branch --cov-fail-under=80 tests
diff --git a/.github/workflows/version-branch-testing.yml b/.github/workflows/version-branch-testing.yml
index 71355d7..056e1f5 100644
--- a/.github/workflows/version-branch-testing.yml
+++ b/.github/workflows/version-branch-testing.yml
@@ -17,7 +17,7 @@ jobs:
- name: Setup the environment
run: rye sync --all-features
- name: Test with pytest
- run: rye run pytest --cov=epochlib --cov-branch --cov-fail-under=95 tests
+ run: rye run pytest --cov=epochlib --cov-branch --cov-fail-under=75 tests
build:
runs-on: ubuntu-latest
diff --git a/epochlib/caching/cacher.py b/epochlib/caching/cacher.py
index 1f2f950..b10946b 100644
--- a/epochlib/caching/cacher.py
+++ b/epochlib/caching/cacher.py
@@ -141,7 +141,7 @@ def cache_exists(self, name: str, cache_args: CacheArgs | None = None) -> bool:
return path_exists
- def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: # noqa: ANN401
+ def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any:
"""Load the cache.
:param name: The name of the cache.
@@ -184,7 +184,7 @@ def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: #
"storage_type must be .npy, .parquet, .csv, or .npy_stack, other types not supported yet",
)
- def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
+ def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .npy file from {storage_path / name}")
if output_data_type == "numpy_array":
@@ -199,7 +199,7 @@ def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_a
"output_data_type must be numpy_array or dask_array, other types not supported yet",
)
- def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
+ def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .parquet file from {storage_path}/{name}")
if output_data_type == "pandas_dataframe":
@@ -226,7 +226,7 @@ def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, re
"output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
)
- def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
+ def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .csv file from {storage_path / name}")
if output_data_type == "pandas_dataframe":
@@ -243,7 +243,7 @@ def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_a
"output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe, other types not supported yet",
)
- def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
+ def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .npy_stack file from {storage_path / name}")
if output_data_type == "dask_array":
@@ -256,7 +256,7 @@ def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str,
"output_data_type must be dask_array, other types not supported yet",
)
- def _load_pkl(self, name: str, storage_path: Path, _output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
+ def _load_pkl(self, name: str, storage_path: Path, _output_data_type: str, read_args: Any) -> Any:
# Load the pickle file
self.log_to_debug(
f"Loading pickle file from {storage_path}/{name}.pkl",
@@ -264,7 +264,7 @@ def _load_pkl(self, name: str, storage_path: Path, _output_data_type: str, read_
with open(storage_path / f"{name}.pkl", "rb") as file:
return pickle.load(file, **read_args) # noqa: S301
- def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None) -> None: # noqa: ANN401
+ def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None) -> None:
"""Store one set of data.
:param name: The name of the cache.
@@ -298,7 +298,7 @@ def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None
self.log_to_debug(f"Invalid storage type: {storage_type}")
raise ValueError(f"storage_type is {storage_type} must be .npy, .parquet, .csv, .npy_stack, or .pkl, other types not supported yet")
- def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
+ def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
file_path = storage_path / f"{name}.npy"
self.log_to_debug(f"Storing .npy file to {file_path}")
if output_data_type == "numpy_array":
@@ -308,7 +308,7 @@ def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type:
else:
raise ValueError("output_data_type must be numpy_array or dask_array")
- def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
+ def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
# Check if output_data_type is supported and store cache to output_data_type
self.log_to_debug(f"Storing .parquet file to {storage_path / name}")
if output_data_type in {"pandas_dataframe", "dask_dataframe"}:
@@ -334,7 +334,7 @@ def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_t
"output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
)
- def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
+ def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
if output_data_type == "pandas_dataframe":
data.to_csv(storage_path / f"{name}.csv", index=False, **store_args)
self.log_to_debug(f"Storing .csv file to {storage_path}/{name}.csv")
@@ -347,7 +347,7 @@ def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type:
else:
raise ValueError("output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe")
- def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
+ def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
# Handling npy_stack case differently as it might need a different path structure
storage_path /= name # Treat name as a directory here
self.log_to_debug(f"Storing .npy_stack file to {storage_path}")
@@ -356,7 +356,7 @@ def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data
else:
raise ValueError("output_data_type must be dask_array")
- def _store_pkl(self, name: str, storage_path: Path, data: Any, _output_data_type: str, store_args: Any) -> None: # noqa: ANN401
+ def _store_pkl(self, name: str, storage_path: Path, data: Any, _output_data_type: str, store_args: Any) -> None:
file_path = storage_path / f"{name}.pkl"
self.log_to_debug(f"Storing pickle file to {file_path}")
with open(file_path, "wb") as f:
diff --git a/epochlib/data/pipeline_dataset.py b/epochlib/data/pipeline_dataset.py
index af171a5..61636d4 100644
--- a/epochlib/data/pipeline_dataset.py
+++ b/epochlib/data/pipeline_dataset.py
@@ -2,7 +2,7 @@
import logging
from dataclasses import dataclass
-from typing import Any, Callable, Tuple, TypeVar
+from typing import Any, Callable, Sequence, Tuple, TypeVar
import numpy as np
import numpy.typing as npt
@@ -43,7 +43,7 @@ class PipelineDataset(Dataset[Tuple[T, T]]):
retrieval: list[str] | None = None
retrieval_type: DataRetrieval | None = None
- steps: list[TrainingBlock] | None = None
+ steps: Sequence[TrainingBlock] | None = None
result_formatter: Callable[[Any], Any] = lambda a: a
def __post_init__(self) -> None:
@@ -77,10 +77,12 @@ def setup_pipeline(self, *, use_augmentations: bool) -> None:
:param use_augmentations: Whether to use augmentations while passing data through pipeline
"""
- self._enabled_steps = []
+        self._enabled_steps: list[TrainingBlock] = []
if self.steps is not None:
for step in self.steps:
+ if not hasattr(step, "is_augmentation"):
+ continue
if (step.is_augmentation and use_augmentations) or not step.is_augmentation:
self._enabled_steps.append(step)
diff --git a/epochlib/ensemble.py b/epochlib/ensemble.py
index 11fe35a..69cb652 100644
--- a/epochlib/ensemble.py
+++ b/epochlib/ensemble.py
@@ -2,13 +2,13 @@
from typing import Any
-from agogos.training import ParallelTrainingSystem
-
from epochlib.caching import CacheArgs
+from epochlib.model import ModelPipeline
+from epochlib.pipeline import ParallelTrainingSystem
class EnsemblePipeline(ParallelTrainingSystem):
- """EnsemblePipeline is the class used to create the pipeline for the model. (Currently same implementation as agogos pipeline).
+ """EnsemblePipeline is the class used to create the pipeline for the model.
:param steps: Trainers to ensemble
"""
@@ -22,7 +22,7 @@ def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
if len(self.steps) == 0:
return False
- return all(step.get_x_cache_exists(cache_args) for step in self.steps)
+ return all(isinstance(step, ModelPipeline) and step.get_x_cache_exists(cache_args) for step in self.steps)
def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of y cache.
@@ -33,9 +33,9 @@ def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
if len(self.steps) == 0:
return False
- return all(step.get_y_cache_exists(cache_args) for step in self.steps)
+ return all(isinstance(step, ModelPipeline) and step.get_y_cache_exists(cache_args) for step in self.steps)
- def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401
+ def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
"""Concatenate the trained data.
:param original_data: First input data
diff --git a/epochlib/model.py b/epochlib/model.py
index 08f7929..9d3f85f 100644
--- a/epochlib/model.py
+++ b/epochlib/model.py
@@ -2,13 +2,12 @@
from typing import Any
-from agogos.training import Pipeline
-
-from epochlib.caching import CacheArgs
+from epochlib.caching import CacheArgs, Cacher
+from epochlib.pipeline import Pipeline
class ModelPipeline(Pipeline):
- """ModelPipeline is the class used to create the pipeline for the model. (Currently same implementation as agogos pipeline).
+ """ModelPipeline is the class used to create the pipeline for the model.
:param x_sys: The system to transform the input data.
:param y_sys: The system to transform the label data.
@@ -24,7 +23,7 @@ def __post_init__(self) -> None:
"""
return super().__post_init__()
- def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Train the system.
:param x: The input to the system.
@@ -33,7 +32,7 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa:
"""
return super().train(x, y, **train_args)
- def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401
+ def predict(self, x: Any, **pred_args: Any) -> Any:
"""Predict the output of the system.
:param x: The input to the system.
@@ -47,7 +46,7 @@ def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
:param cache_args: Cache arguments
:return: Whether cache exists
"""
- if self.x_sys is None:
+ if self.x_sys is None or not isinstance(self.x_sys, Cacher):
return False
return self.x_sys.cache_exists(self.x_sys.get_hash(), cache_args)
@@ -57,7 +56,7 @@ def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
:param cache_args: Cache arguments
:return: Whether cache exists
"""
- if self.y_sys is None:
+ if self.y_sys is None or not isinstance(self.y_sys, Cacher):
return False
return self.y_sys.cache_exists(self.y_sys.get_hash(), cache_args)
diff --git a/epochlib/pipeline/__init__.py b/epochlib/pipeline/__init__.py
new file mode 100644
index 0000000..745d851
--- /dev/null
+++ b/epochlib/pipeline/__init__.py
@@ -0,0 +1,21 @@
+"""Core pipeline functionality for training and transforming data."""
+
+from .core import Base, Block, ParallelSystem, SequentialSystem
+from .training import ParallelTrainingSystem, Pipeline, Trainer, TrainingSystem, TrainType
+from .transforming import ParallelTransformingSystem, Transformer, TransformingSystem, TransformType
+
+__all__ = [
+ "TrainType",
+ "Trainer",
+ "TrainingSystem",
+ "ParallelTrainingSystem",
+ "Pipeline",
+ "TransformType",
+ "Transformer",
+ "TransformingSystem",
+ "ParallelTransformingSystem",
+ "Base",
+ "SequentialSystem",
+ "ParallelSystem",
+ "Block",
+]
diff --git a/epochlib/pipeline/core.py b/epochlib/pipeline/core.py
new file mode 100644
index 0000000..b4e2935
--- /dev/null
+++ b/epochlib/pipeline/core.py
@@ -0,0 +1,285 @@
+"""This module contains the core classes for all classes in the epochlib package."""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Sequence
+
+from joblib import hash
+
+
+@dataclass
+class Base:
+ """The Base class is the base class for all classes in the epochlib package.
+
+ Methods:
+ .. code-block:: python
+ def get_hash(self) -> str:
+ # Get the hash of base
+
+ def get_parent(self) -> Any:
+ # Get the parent of base.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of base
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+ def __post_init__(self) -> None:
+ """Initialize the block."""
+ self.set_hash("")
+ self.set_parent(None)
+ self.set_children([])
+
+ def set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the block.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = hash(prev_hash + str(self))
+
+ def get_hash(self) -> str:
+ """Get the hash of the block.
+
+ :return: The hash of the block.
+ """
+ return self._hash
+
+ def get_parent(self) -> Any:
+ """Get the parent of the block.
+
+ :return: Parent of the block
+ """
+ return self._parent
+
+ def get_children(self) -> Sequence[Any]:
+ """Get the children of the block.
+
+ :return: Children of the block
+ """
+ return self._children
+
+ def save_to_html(self, file_path: Path) -> None:
+ """Write html representation of class to file.
+
+ :param file_path: File path to write to
+ """
+ html = self._repr_html_()
+ with open(file_path, "w") as file:
+ file.write(html)
+
+ def set_parent(self, parent: Any) -> None:
+ """Set the parent of the block.
+
+ :param parent: Parent of the block
+ """
+ self._parent = parent
+
+ def set_children(self, children: Sequence[Any]) -> None:
+ """Set the children of the block.
+
+ :param children: Children of the block
+ """
+ self._children = children
+
+ def _repr_html_(self) -> str:
+ """Return representation of class in html format.
+
+ :return: String representation of html
+ """
+        html = "<div>"
+        html += f"<p>Class: {self.__class__.__name__}</p>"
+        html += "<ul>"
+        html += f"<li>Hash: {self.get_hash()}</li>"
+        html += f"<li>Parent: {self.get_parent()}</li>"
+        html += "<li>Children: "
+        if self.get_children():
+            html += "<ul>"
+            for child in self.get_children():
+                html += f"<li>{child._repr_html_()}</li>"
+            html += "</ul>"
+        else:
+            html += "None"
+        html += "</li>"
+        html += "</ul>"
+        html += "</div>"
+        return html
+
+
+class Block(Base):
+ """The Block class is the base class for all blocks.
+
+ Methods:
+ .. code-block:: python
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+
+@dataclass
+class ParallelSystem(Base):
+ """The System class is the base class for all systems.
+
+ Parameters:
+ - steps (list[_Base]): The steps in the system.
+ - weights (list[float]): Weights of steps in the system, if not specified they are equal.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+        def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
+ # Specifies how to concat data after parallel computations
+
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+ steps: list[Base] = field(default_factory=list)
+ weights: list[float] = field(default_factory=list)
+
+    def __post_init__(self) -> None:
+        """Post init function of ParallelSystem class."""
+        # Sort the steps by name, to ensure consistent ordering of parallel computations
+        self.steps = sorted(self.steps, key=lambda x: x.__class__.__name__)
+
+        super().__post_init__()
+
+        # Set parent and children
+        for step in self.steps:
+            step.set_parent(self)
+
+        # Normalize user-supplied weights so they sum to 1; default to equal weights.
+        if len(self.weights) == len(self.get_steps()):
+            self.weights = [w / sum(self.weights) for w in self.weights]
+        else:
+            num_steps = len(self.get_steps())
+            self.weights = [1 / num_steps for _ in self.steps]
+
+        self.set_children(self.steps)
+
+ def get_steps(self) -> list[Base]:
+ """Return list of steps of ParallelSystem.
+
+ :return: List of steps
+ """
+ return self.steps
+
+ def get_weights(self) -> list[float]:
+ """Return list of weights of ParallelSystem.
+
+ :return: List of weights
+ """
+ if len(self.get_steps()) != len(self.weights):
+ raise TypeError("Mismatch between weights and steps")
+ return self.weights
+
+ def set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the system.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = prev_hash
+
+ # System has no steps and as such hash should not be affected
+ if len(self.steps) == 0:
+ return
+
+ # System is one step and should act as such
+ if len(self.steps) == 1:
+ step = self.steps[0]
+ step.set_hash(prev_hash)
+ self._hash = step.get_hash()
+ return
+
+ # System has at least two steps so hash should become a combination
+ total = self.get_hash()
+ for step in self.steps:
+ step.set_hash(prev_hash)
+ total = total + step.get_hash()
+
+ self._hash = hash(total)
+
+ @abstractmethod
+ def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
+ """Concatenate the transformed data.
+
+ :param original_data: The first input data.
+ :param data_to_concat: The second input data.
+ :param weight: Weight of data to concat
+ :return: The concatenated data.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement concat method.")
+
+
+@dataclass
+class SequentialSystem(Base):
+ """The SequentialSystem class is the base class for all systems.
+
+ Parameters:
+ - steps (list[_Base]): The steps in the system.
+
+ Methods:
+ .. code-block:: python
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+ steps: Sequence[Base] = field(default_factory=list)
+
+ def __post_init__(self) -> None:
+ """Post init function of _System class."""
+ super().__post_init__()
+
+ # Set parent and children
+ for step in self.steps:
+ step.set_parent(self)
+
+ self.set_children(self.steps)
+
+ def get_steps(self) -> Sequence[Base]:
+ """Return list of steps of _ParallelSystem.
+
+ :return: List of steps
+ """
+ return self.steps
+
+ def set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the system.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = prev_hash
+
+ # Set hash of each step using previous hash and then update hash with last step
+ for step in self.steps:
+ step.set_hash(self.get_hash())
+ self._hash = step.get_hash()
diff --git a/epochlib/pipeline/training.py b/epochlib/pipeline/training.py
new file mode 100644
index 0000000..dabada3
--- /dev/null
+++ b/epochlib/pipeline/training.py
@@ -0,0 +1,436 @@
+"""This module contains classes for training and predicting on data."""
+
+import copy
+import warnings
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from joblib import hash
+
+from .core import Base, Block, ParallelSystem, SequentialSystem
+from .transforming import TransformingSystem
+
+
+class TrainType(Base):
+ """Abstract train type describing a class that implements two functions train and predict."""
+
+ @abstractmethod
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the block.
+
+ :param x: The input data.
+ :param y: The target variable.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement train method.")
+
+ @abstractmethod
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the target variable.
+
+ :param x: The input data.
+ :return: The predictions.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement predict method.")
+
+
+class Trainer(TrainType, Block):
+ """The trainer block is for blocks that need to train on two inputs and predict on one.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ # Train the block.
+
+ @abstractmethod
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ # Predict the target variable.
+
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import Trainer
+
+
+ class MyTrainer(Trainer):
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ # Train the block.
+ return x, y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ # Predict the target variable.
+ return x
+
+
+ my_trainer = MyTrainer()
+ predictions, labels = my_trainer.train(x, y)
+ predictions = my_trainer.predict(x)
+ """
+
+
+class TrainingSystem(TrainType, SequentialSystem):
+ """A system that trains on the input data and labels.
+
+ Parameters:
+ - steps (list[TrainType]): The steps in the system.
+
+ Methods:
+ .. code-block:: python
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # Train the system.
+
+ def predict(self, x: Any, **pred_args: Any) -> Any: # Predict the output of the system.
+
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import TrainingSystem
+
+ trainer_1 = CustomTrainer()
+ trainer_2 = CustomTrainer()
+
+ training_system = TrainingSystem(steps=[trainer_1, trainer_2])
+ trained_x, trained_y = training_system.train(x, y)
+ predictions = training_system.predict(x)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the TrainingSystem class."""
+ # Assert all steps are a subclass of Trainer
+ for step in self.steps:
+ if not isinstance(
+ step,
+ (TrainType),
+ ):
+ raise TypeError(f"step: {step} is not an instance of TrainType")
+
+ super().__post_init__()
+
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the system.
+
+ :param x: The input to the system.
+ :param y: The output of the system.
+ :return: The input and output of the system.
+ """
+ set_of_steps = set()
+ for step in self.steps:
+ step_name = step.__class__.__name__
+ set_of_steps.add(step_name)
+
+ if set_of_steps != set(train_args.keys()):
+ # Raise a warning and print all the keys that do not match
+ warnings.warn(f"The following steps do not exist but were given in the kwargs: {set(train_args.keys()) - set_of_steps}", UserWarning, stacklevel=2)
+
+ # Loop through each step and call the train method
+ for step in self.steps:
+ step_name = step.__class__.__name__
+
+ step_args = train_args.get(step_name, {})
+ if isinstance(step, (TrainType)):
+ x, y = step.train(x, y, **step_args)
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return x, y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the output of the system.
+
+ :param x: The input to the system.
+ :return: The output of the system.
+ """
+ set_of_steps = set()
+ for step in self.steps:
+ step_name = step.__class__.__name__
+ set_of_steps.add(step_name)
+
+ if set_of_steps != set(pred_args.keys()):
+ # Raise a warning and print all the keys that do not match
+ warnings.warn(f"The following steps do not exist but were given in the kwargs: {set(pred_args.keys()) - set_of_steps}", UserWarning, stacklevel=2)
+
+ # Loop through each step and call the predict method
+ for step in self.steps:
+ step_name = step.__class__.__name__
+
+ step_args = pred_args.get(step_name, {})
+
+ if isinstance(step, (TrainType)):
+ x = step.predict(x, **step_args)
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return x
+
+
+class ParallelTrainingSystem(TrainType, ParallelSystem):
+ """A system that trains the input data in parallel.
+
+ Parameters:
+ - steps (list[Trainer | TrainingSystem | ParallelTrainingSystem]): The steps in the system.
+ - weights (list[float]): The weights of steps in the system, if not specified they are all equal.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def concat(self, data1: Any, data2: Any) -> Any: # Concatenate the transformed data.
+
+ def train(self, x: Any, y: Any) -> tuple[Any, Any]: # Train the system.
+
+ def predict(self, x: Any, pred_args: dict[str, Any] = {}) -> Any: # Predict the output of the system.
+
+ def concat_labels(self, data1: Any, data2: Any) -> Any: # Concatenate the transformed labels.
+
+ def get_hash(self) -> str: # Get the hash of the system.
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import ParallelTrainingSystem
+
+ trainer_1 = CustomTrainer()
+ trainer_2 = CustomTrainer()
+
+
+ class CustomParallelTrainingSystem(ParallelTrainingSystem):
+ def concat(self, data1: Any, data2: Any) -> Any:
+ # Concatenate the transformed data.
+ return data1 + data2
+
+
+ training_system = CustomParallelTrainingSystem(steps=[trainer_1, trainer_2])
+ trained_x, trained_y = training_system.train(x, y)
+ predictions = training_system.predict(x)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the ParallelTrainingSystem class."""
+ # Assert all steps correct instances
+ for step in self.steps:
+ if not isinstance(step, (TrainType)):
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ super().__post_init__()
+
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the system.
+
+ :param x: The input to the system.
+ :param y: The expected output of the system.
+ :return: The input and output of the system.
+ """
+ # Loop through each step and call the train method
+ out_x, out_y = None, None
+ for i, step in enumerate(self.steps):
+ step_name = step.__class__.__name__
+
+ step_args = train_args.get(step_name, {})
+
+ if isinstance(step, (TrainType)):
+ step_x, step_y = step.train(copy.deepcopy(x), copy.deepcopy(y), **step_args)
+ out_x, out_y = (
+ self.concat(out_x, step_x, self.get_weights()[i]),
+ self.concat_labels(out_y, step_y, self.get_weights()[i]),
+ )
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return out_x, out_y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the output of the system.
+
+ :param x: The input to the system.
+ :return: The output of the system.
+ """
+ # Loop through each trainer and call the predict method
+ out_x = None
+ for i, step in enumerate(self.steps):
+ step_name = step.__class__.__name__
+
+ step_args = pred_args.get(step_name, {})
+
+ if isinstance(step, (TrainType)):
+ step_x = step.predict(copy.deepcopy(x), **step_args)
+ out_x = self.concat(out_x, step_x, self.get_weights()[i])
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return out_x
+
+ def concat_labels(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
+ """Concatenate the transformed labels. Will use concat method if not overridden.
+
+ :param original_data: The first input data.
+ :param data_to_concat: The second input data.
+ :param weight: Weight of data to concat
+ :return: The concatenated data.
+ """
+ return self.concat(original_data, data_to_concat, weight)
+
+
+@dataclass
+class Pipeline(TrainType):
+ """A pipeline of systems that can be trained and predicted.
+
+ Parameters:
+ - x_sys (TransformingSystem | None): The system to transform the input data.
+ - y_sys (TransformingSystem | None): The system to transform the labelled data.
+ - train_sys (TrainingSystem | None): The system to train the data.
+ - pred_sys (TransformingSystem | None): The system to transform the predictions.
+ - label_sys (TransformingSystem | None): The system to transform the labels.
+
+ Methods:
+ .. code-block:: python
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ # Train the system.
+
+ def predict(self, x: Any, **pred_args) -> Any:
+ # Predict the output of the system.
+
+ def get_hash(self) -> str:
+ # Get the hash of the pipeline
+
+ def get_parent(self) -> Any:
+ # Get the parent of the pipeline
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the pipeline
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import Pipeline
+
+ x_sys = CustomTransformingSystem()
+ y_sys = CustomTransformingSystem()
+ train_sys = CustomTrainingSystem()
+ pred_sys = CustomTransformingSystem()
+ label_sys = CustomTransformingSystem()
+
+ pipeline = Pipeline(x_sys=x_sys, y_sys=y_sys, train_sys=train_sys, pred_sys=pred_sys, label_sys=label_sys)
+ trained_x, trained_y = pipeline.train(x, y)
+ predictions = pipeline.predict(x)
+ """
+
+ x_sys: TransformingSystem | None = None
+ y_sys: TransformingSystem | None = None
+ train_sys: Trainer | TrainingSystem | ParallelTrainingSystem | None = None
+ pred_sys: TransformingSystem | None = None
+ label_sys: TransformingSystem | None = None
+
+ def __post_init__(self) -> None:
+ """Post initialization function of the Pipeline."""
+ super().__post_init__()
+
+ # Set children and parents
+ children = []
+ systems = [
+ self.x_sys,
+ self.y_sys,
+ self.train_sys,
+ self.pred_sys,
+ self.label_sys,
+ ]
+
+ for sys in systems:
+ if sys is not None:
+ sys.set_parent(self)
+ children.append(sys)
+
+ self.set_children(children)
+
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the system.
+
+ :param x: The input to the system.
+ :param y: The expected output of the system.
+ :param train_args: The arguments to pass to the training system. (Default is {})
+ :return: The input and output of the system.
+ """
+ if self.x_sys is not None:
+ x = self.x_sys.transform(x, **train_args.get("x_sys", {}))
+ if self.y_sys is not None:
+ y = self.y_sys.transform(y, **train_args.get("y_sys", {}))
+ if self.train_sys is not None:
+ x, y = self.train_sys.train(x, y, **train_args.get("train_sys", {}))
+ if self.pred_sys is not None:
+ x = self.pred_sys.transform(x, **train_args.get("pred_sys", {}))
+ if self.label_sys is not None:
+ y = self.label_sys.transform(y, **train_args.get("label_sys", {}))
+
+ return x, y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the output of the system.
+
+ :param x: The input to the system.
+ :param pred_args: The arguments to pass to the prediction system. (Default is {})
+ :return: The output of the system.
+ """
+ if self.x_sys is not None:
+ x = self.x_sys.transform(x, **pred_args.get("x_sys", {}))
+ if self.train_sys is not None:
+ x = self.train_sys.predict(x, **pred_args.get("train_sys", {}))
+ if self.pred_sys is not None:
+ x = self.pred_sys.transform(x, **pred_args.get("pred_sys", {}))
+
+ return x
+
+    def set_hash(self, prev_hash: str) -> None:
+        """Set the hash of the pipeline.
+
+        :param prev_hash: The hash of the previous block.
+        """
+        self._hash = prev_hash
+
+        xy_hash = ""
+        if self.x_sys is not None:
+            self.x_sys.set_hash(self.get_hash())
+            xy_hash += self.x_sys.get_hash()
+        if self.y_sys is not None:
+            self.y_sys.set_hash(self.get_hash())
+            xy_hash += self.y_sys.get_hash()[::-1]  # Reversed for edge case where you have two pipelines with the same system but one in x the other in y
+
+        if xy_hash != "":
+            self._hash = hash(xy_hash)
+
+        if self.train_sys is not None:
+            self.train_sys.set_hash(self.get_hash())
+            training_hash = self.train_sys.get_hash()
+            if training_hash != "":
+                self._hash = hash(self._hash + training_hash)
+
+        predlabel_hash = ""
+        if self.pred_sys is not None:
+            self.pred_sys.set_hash(self.get_hash())
+            predlabel_hash += self.pred_sys.get_hash()
+        if self.label_sys is not None:
+            self.label_sys.set_hash(self.get_hash())
+            predlabel_hash += self.label_sys.get_hash()
+
+        if predlabel_hash != "":
+            self._hash = hash(predlabel_hash)  # NOTE(review): prior hash not folded in here, unlike train_sys above — confirm intentional
diff --git a/epochlib/pipeline/transforming.py b/epochlib/pipeline/transforming.py
new file mode 100644
index 0000000..529aa26
--- /dev/null
+++ b/epochlib/pipeline/transforming.py
@@ -0,0 +1,209 @@
+"""This module contains the classes for transforming data in the epochlib package."""
+
+import copy
+import warnings
+from abc import abstractmethod
+from typing import Any
+
+from .core import Base, Block, ParallelSystem, SequentialSystem
+
+
+class TransformType(Base):
+ """Abstract transform type describing a class that implements the transform function."""
+
+ @abstractmethod
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ """Transform the input data.
+
+ :param data: The input data.
+ :param transform_args: Keyword arguments.
+ :return: The transformed data.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement transform method.")
+
+
+class Transformer(TransformType, Block):
+ """The transformer block transforms any data it could be x or y data.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ # Transform the input data.
+
+ def get_hash(self) -> str:
+ # Get the hash of the Transformer
+
+ def get_parent(self) -> Any:
+ # Get the parent of the Transformer
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the Transformer
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import Transformer
+
+
+ class MyTransformer(Transformer):
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ # Transform the input data.
+ return data
+
+
+ my_transformer = MyTransformer()
+ transformed_data = my_transformer.transform(data)
+ """
+
+
+class TransformingSystem(TransformType, SequentialSystem):
+ """A system that transforms the input data.
+
+ Parameters:
+ - steps (list[Transformer | TransformingSystem | ParallelTransformingSystem]): The steps in the system.
+
+ Implements the following methods:
+ .. code-block:: python
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ # Transform the input data.
+
+ def get_hash(self) -> str:
+ # Get the hash of the TransformingSystem
+
+ def get_parent(self) -> Any:
+ # Get the parent of the TransformingSystem
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the TransformingSystem
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import TransformingSystem
+
+ transformer_1 = CustomTransformer()
+ transformer_2 = CustomTransformer()
+
+ transforming_system = TransformingSystem(steps=[transformer_1, transformer_2])
+ transformed_data = transforming_system.transform(data)
+        # NOTE: TransformingSystem has no predict method; use transform for inference data too.
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the TransformingSystem class."""
+ # Assert all steps are a subclass of Transformer
+ for step in self.steps:
+ if not isinstance(step, (TransformType)):
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ super().__post_init__()
+
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ """Transform the input data.
+
+ :param data: The input data.
+ :return: The transformed data.
+ """
+ set_of_steps = set()
+ for step in self.steps:
+ step_name = step.__class__.__name__
+ set_of_steps.add(step_name)
+ if set_of_steps != set(transform_args.keys()):
+ # Raise a warning and print all the keys that do not match
+ warnings.warn(f"The following steps do not exist but were given in the kwargs: {set(transform_args.keys()) - set_of_steps}", stacklevel=2)
+
+ # Loop through each step and call the transform method
+ for step in self.steps:
+ step_name = step.__class__.__name__
+
+ step_args = transform_args.get(step_name, {})
+ if isinstance(step, (TransformType)):
+ data = step.transform(data, **step_args)
+ else:
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ return data
+
+
+class ParallelTransformingSystem(TransformType, ParallelSystem):
+ """A system that transforms the input data in parallel.
+
+ Parameters:
+ - steps (list[Transformer | TransformingSystem | ParallelTransformingSystem]): The steps in the system.
+ - weights (list[float]): Weights of steps in system, if not specified they are all equal.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+        def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
+ # Specifies how to concat data after parallel computations
+
+ def get_hash(self) -> str:
+ # Get the hash of the ParallelTransformingSystem.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the ParallelTransformingSystem.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the ParallelTransformingSystem
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import ParallelTransformingSystem
+
+ transformer_1 = CustomTransformer()
+ transformer_2 = CustomTransformer()
+
+
+ class CustomParallelTransformingSystem(ParallelTransformingSystem):
+            def concat(self, data1: Any, data2: Any, weight: float = 1.0) -> Any:
+ # Concatenate the transformed data.
+ return data1 + data2
+
+
+ transforming_system = CustomParallelTransformingSystem(steps=[transformer_1, transformer_2])
+
+ transformed_data = transforming_system.transform(data)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the ParallelTransformingSystem class."""
+ # Assert all steps are a subclass of Transformer or TransformingSystem
+ for step in self.steps:
+ if not isinstance(step, (TransformType)):
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ super().__post_init__()
+
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ """Transform the input data.
+
+ :param data: The input data.
+ :return: The transformed data.
+ """
+ # Loop through each step and call the transform method
+ out_data = None
+ if len(self.get_steps()) == 0:
+ return data
+
+ for i, step in enumerate(self.get_steps()):
+ step_name = step.__class__.__name__
+
+ step_args = transform_args.get(step_name, {})
+
+ if isinstance(step, (TransformType)):
+ step_data = step.transform(copy.deepcopy(data), **step_args)
+ out_data = self.concat(out_data, step_data, self.get_weights()[i])
+ else:
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ return out_data
diff --git a/epochlib/training/augmentation/image_augmentations.py b/epochlib/training/augmentation/image_augmentations.py
index bf042f7..86c27a1 100644
--- a/epochlib/training/augmentation/image_augmentations.py
+++ b/epochlib/training/augmentation/image_augmentations.py
@@ -6,7 +6,7 @@
import torch
-def get_kornia_mix() -> Any: # noqa: ANN401
+def get_kornia_mix() -> Any:
"""Return kornia mix."""
try:
import kornia
diff --git a/epochlib/training/pretrain_block.py b/epochlib/training/pretrain_block.py
index 487bc1b..3b69a06 100644
--- a/epochlib/training/pretrain_block.py
+++ b/epochlib/training/pretrain_block.py
@@ -54,7 +54,7 @@ def custom_predict(self, x: Any, **pred_args: Any) -> Any:
test_size: float = 0.2
@abstractmethod
- def pretrain_train(self, x: Any, y: Any, train_indices: list[int], *, save_pretrain: bool = True, save_pretrain_with_split: bool = False) -> tuple[Any, Any]: # noqa: ANN401
+ def pretrain_train(self, x: Any, y: Any, train_indices: list[int], *, save_pretrain: bool = True, save_pretrain_with_split: bool = False) -> tuple[Any, Any]:
"""Train pretrain block method.
:param x: The input to the system.
@@ -67,7 +67,7 @@ def pretrain_train(self, x: Any, y: Any, train_indices: list[int], *, save_pretr
f"Train method not implemented for {self.__class__.__name__}",
)
- def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Call the pretrain train method.
:param x: The input to the system.
diff --git a/epochlib/training/torch_trainer.py b/epochlib/training/torch_trainer.py
index ead30bf..a2ba369 100644
--- a/epochlib/training/torch_trainer.py
+++ b/epochlib/training/torch_trainer.py
@@ -409,7 +409,7 @@ def predict_after_train(
case _:
raise ValueError("to_predict should be either 'validation', 'all' or 'none")
- def custom_predict(self, x: Any, **pred_args: Any) -> npt.NDArray[np.float32]: # noqa: ANN401
+ def custom_predict(self, x: Any, **pred_args: Any) -> npt.NDArray[np.float32]:
"""Predict on the validation data.
:param x: The input to the system.
diff --git a/epochlib/training/training.py b/epochlib/training/training.py
index 5ce4569..0b35ba0 100644
--- a/epochlib/training/training.py
+++ b/epochlib/training/training.py
@@ -3,9 +3,8 @@
from dataclasses import dataclass
from typing import Any
-from agogos.training import TrainingSystem, TrainType
-
from epochlib.caching import CacheArgs, Cacher
+from epochlib.pipeline import TrainingSystem, TrainType
@dataclass
@@ -15,7 +14,7 @@ class TrainingPipeline(TrainingSystem, Cacher):
:param steps: The steps to train the model.
"""
- def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]:
"""Train the system.
:param x: The input to the system.
@@ -38,7 +37,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
# Furthest step
for i, step in enumerate(self.get_steps()):
# Check if step is instance of Cacher and if cache_args exists
- if not isinstance(step, Cacher) or not isinstance(step, TrainType):
+ if not isinstance(step, TrainType) or not isinstance(step, Cacher):
self.log_to_debug(f"{step} is not instance of Cacher or TrainType")
continue
@@ -74,7 +73,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
return x, y
- def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any: # noqa: ANN401
+ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any:
"""Predict the output of the system.
:param x: The input to the system.
@@ -92,7 +91,7 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)
# Retrieve furthest step calculated
for i, step in enumerate(self.get_steps()):
# Check if step is instance of Cacher and if cache_args exists
- if not isinstance(step, Cacher) or not isinstance(step, TrainType):
+ if not isinstance(step, TrainType) or not isinstance(step, Cacher):
self.log_to_debug(f"{step} is not instance of Cacher or TrainType")
continue
diff --git a/epochlib/training/training_block.py b/epochlib/training/training_block.py
index 7ce5b0b..c37d556 100644
--- a/epochlib/training/training_block.py
+++ b/epochlib/training/training_block.py
@@ -3,9 +3,8 @@
from abc import abstractmethod
from typing import Any
-from agogos.training import Trainer
-
from epochlib.caching import CacheArgs, Cacher
+from epochlib.pipeline import Trainer
class TrainingBlock(Trainer, Cacher):
@@ -58,7 +57,7 @@ def custom_predict(self, x: Any) -> Any:
x = custom_training_block.predict(x)
"""
- def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]:
"""Train the model.
:param x: The input data.
@@ -92,7 +91,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
return x, y
@abstractmethod
- def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Train the model.
:param x: The input data.
@@ -103,7 +102,7 @@ def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: #
f"Custom transform method not implemented for {self.__class__}",
)
- def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any: # noqa: ANN401
+ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any:
"""Predict using the model.
:param x: The input data.
@@ -129,7 +128,7 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)
return x
@abstractmethod
- def custom_predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401
+ def custom_predict(self, x: Any, **pred_args: Any) -> Any:
"""Predict using the model.
:param x: The input data.
diff --git a/epochlib/training/utils/get_dependencies.py b/epochlib/training/utils/get_dependencies.py
index b4f3239..2bd7273 100644
--- a/epochlib/training/utils/get_dependencies.py
+++ b/epochlib/training/utils/get_dependencies.py
@@ -3,7 +3,7 @@
from typing import Any
-def _get_onnxrt() -> Any: # noqa: ANN401
+def _get_onnxrt() -> Any:
"""Return onnxruntime."""
try:
import onnxruntime as onnxrt
@@ -17,7 +17,7 @@ def _get_onnxrt() -> Any: # noqa: ANN401
return onnxrt
-def _get_openvino() -> Any: # noqa: ANN401
+def _get_openvino() -> Any:
"""Return openvino."""
try:
import openvino
diff --git a/epochlib/transformation/transformation.py b/epochlib/transformation/transformation.py
index 34437e3..a8515e7 100644
--- a/epochlib/transformation/transformation.py
+++ b/epochlib/transformation/transformation.py
@@ -3,9 +3,8 @@
from dataclasses import dataclass
from typing import Any
-from agogos.transforming import TransformingSystem, TransformType
-
from epochlib.caching.cacher import CacheArgs, Cacher
+from epochlib.pipeline import TransformingSystem, TransformType
@dataclass
@@ -60,7 +59,7 @@ def log_to_terminal(self, message: str) -> None:
title: str = "Transformation Pipeline" # The title of the pipeline since transformation pipeline can be used for multiple purposes. (Feature, Label, etc.)
- def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any: # noqa: ANN401
+ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any:
"""Transform the input data.
:param data: The input data.
@@ -81,7 +80,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
# Furthest step
for i, step in enumerate(self.get_steps()):
# Check if step is instance of Cacher and if cache_args exists
- if not isinstance(step, Cacher) or not isinstance(step, TransformType):
+ if not isinstance(step, TransformType) or not isinstance(step, Cacher):
self.log_to_debug(f"{step} is not instance of Cacher or TransformType")
continue
diff --git a/epochlib/transformation/transformation_block.py b/epochlib/transformation/transformation_block.py
index db817f2..fbe82cf 100644
--- a/epochlib/transformation/transformation_block.py
+++ b/epochlib/transformation/transformation_block.py
@@ -3,9 +3,8 @@
from abc import abstractmethod
from typing import Any
-from agogos.transforming import Transformer
-
from epochlib.caching.cacher import CacheArgs, Cacher
+from epochlib.pipeline import Transformer
class TransformationBlock(Transformer, Cacher):
@@ -55,7 +54,7 @@ def custom_transform(self, data: Any) -> Any:
data = custom_transformation_block.transform(data, cache=cache_args)
"""
- def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any: # noqa: ANN401
+ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any:
"""Transform the input data using a custom method.
:param data: The input data.
@@ -78,7 +77,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
return data
@abstractmethod
- def custom_transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401
+ def custom_transform(self, data: Any, **transform_args: Any) -> Any:
"""Transform the input data using a custom method.
:param data: The input data.
diff --git a/pyproject.toml b/pyproject.toml
index d32526e..0042461 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,6 @@ classifiers = [
]
dependencies = [
"torch>=2.1.0",
- "agogos>=0.4",
"joblib>=1.4.0",
"annotated-types>=0.6.0",
"typing-extensions>=4.9.0; python_version<'3.12'",
diff --git a/requirements-dev.lock b/requirements-dev.lock
index e428ece..e91e3c4 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -12,8 +12,6 @@
--index-url https://pypi.org/simple/
-e file:.
-agogos==0.4
- # via epochlib
alabaster==0.7.16
# via sphinx
annotated-types==0.7.0
@@ -90,7 +88,6 @@ jinja2==3.1.4
# via sphinx
# via torch
joblib==1.4.2
- # via agogos
# via epochlib
# via librosa
# via scikit-learn
@@ -167,9 +164,9 @@ nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
-nvidia-nccl-cu12==2.20.5
+nvidia-nccl-cu12==2.19.3
# via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.8.61
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
@@ -289,16 +286,16 @@ tomli==2.0.1
toolz==0.12.1
# via dask
# via partd
-torch==2.3.1
+torch==2.2.2
# via epochlib
# via kornia
# via timm
# via torchvision
-torchvision==0.18.1
+torchvision==0.17.2
# via timm
tqdm==4.66.4
# via huggingface-hub
-triton==2.3.1
+triton==2.2.0
# via torch
typing-extensions==4.12.2
# via epochlib
diff --git a/requirements.lock b/requirements.lock
index c7fbfd3..f20a68e 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -12,8 +12,6 @@
--index-url https://pypi.org/simple/
-e file:.
-agogos==0.4
- # via epochlib
annotated-types==0.7.0
# via epochlib
audiomentations==0.36.0
@@ -63,7 +61,6 @@ importlib-metadata==7.2.1
jinja2==3.1.4
# via torch
joblib==1.4.2
- # via agogos
# via epochlib
# via librosa
# via scikit-learn
@@ -130,9 +127,9 @@ nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
-nvidia-nccl-cu12==2.20.5
+nvidia-nccl-cu12==2.19.3
# via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.8.61
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
@@ -211,16 +208,16 @@ timm==1.0.7
toolz==0.12.1
# via dask
# via partd
-torch==2.3.1
+torch==2.2.2
# via epochlib
# via kornia
# via timm
# via torchvision
-torchvision==0.18.1
+torchvision==0.17.2
# via timm
tqdm==4.66.4
# via huggingface-hub
-triton==2.3.1
+triton==2.2.0
# via torch
typing-extensions==4.12.2
# via epochlib
diff --git a/ruff.toml b/ruff.toml
index 905d827..b07f196 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -32,6 +32,7 @@ ignore = [
# flake8-annotations (ANN)
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
+ "ANN401", # Allow Any type in epochlib
# flake8-errmsg (EM)
"EM101", # Exception must not use a string literal, assign to variable first
"EM102", # Exception must not use an f-string literal, assign to variable first
diff --git a/tests/data/test_pipeline_dataset.py b/tests/data/test_pipeline_dataset.py
index 0770151..d18bff4 100644
--- a/tests/data/test_pipeline_dataset.py
+++ b/tests/data/test_pipeline_dataset.py
@@ -6,12 +6,13 @@
import numpy as np
import numpy.typing as npt
from typing import Any
-import torch
+
class TestDataRetrieval(DataRetrieval):
BASE = 2**0
FIRST = 2**1
+
@dataclass
class CustomData(Data):
data1: npt.NDArray[np.int_] | None = None
@@ -19,7 +20,7 @@ class CustomData(Data):
def __post_init__(self) -> None:
self.retrieval = TestDataRetrieval.BASE
-
+
def __getitem__(self, idx: int | npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any] | list[Any]:
"""Get item from the data.
@@ -64,6 +65,7 @@ def __len__(self) -> int:
return len(self.data2)
return 0
+
class TestTrainingBlockNoAug(TrainingBlock):
def train(
@@ -80,6 +82,7 @@ def is_augmentation(self) -> bool:
"""Check if augmentation is enabled."""
return False
+
class TestTrainingBlockWithAug(TrainingBlock):
def train(
@@ -96,8 +99,9 @@ def is_augmentation(self) -> bool:
"""Check if augmentation is enabled."""
return True
+
class TestPipelineDataset(TestCase):
-
+
def test_initialization_errors(self) -> None:
with self.assertRaises(ValueError):
PipelineDataset()
@@ -144,16 +148,16 @@ def test_get_items(self) -> None:
pd_with_data = PipelineDataset(
retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data
)
- self.assertTrue((pd_with_data[[0,1]][0] == [0,1]).all())
+ self.assertTrue((pd_with_data[[0, 1]][0] == [0, 1]).all())
pd_with_indices = PipelineDataset(
- retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data, indices=np.array([0,2])
+ retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data, indices=np.array([0, 2])
)
- self.assertTrue((pd_with_indices[[0,1]][0] == [0,2]).all())
+ self.assertTrue((pd_with_indices[[0, 1]][0] == [0, 2]).all())
pd_no_data = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step])
with self.assertRaises(ValueError):
- pd_no_data[[0,1]]
+ pd_no_data[[0, 1]]
def test_len(self) -> None:
test_data = CustomData()
@@ -165,12 +169,12 @@ def test_len(self) -> None:
)
self.assertTrue(len(pd_with_data) == 3)
- pd_with_data.initialize(x=test_data, y=test_data, indices=np.array([0,2]))
+ pd_with_data.initialize(x=test_data, y=test_data, indices=np.array([0, 2]))
# pd_with_indices = PipelineDataset(
# retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data, indices=np.array([0,2])
# )
self.assertTrue(len(pd_with_data) == 2)
-
+
pd_no_data = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step])
with self.assertRaises(ValueError):
self.assertTrue(len(pd_no_data))
diff --git a/tests/pipeline/test__core.py b/tests/pipeline/test__core.py
new file mode 100644
index 0000000..8e26b22
--- /dev/null
+++ b/tests/pipeline/test__core.py
@@ -0,0 +1,170 @@
+from epochlib.pipeline import Block, Base, SequentialSystem, ParallelSystem
+from tests.pipeline.util import remove_cache_files
+from pathlib import Path
+
+
+class Test_Base:
+ def test_init(self):
+ base = Base()
+ assert base is not None
+
+ def test_set_hash(self):
+ base = Base()
+ prev_hash = base.get_hash()
+ base.set_hash("prev_hash")
+ assert base.get_hash() != prev_hash
+
+ def test_get_children(self):
+ base = Base()
+ assert base.get_children() == []
+
+ def test_get_parent(self):
+ base = Base()
+ assert base.get_parent() is None
+
+ def test__set_parent(self):
+ base = Base()
+ base.set_parent(base)
+ assert base.get_parent() == base
+
+ def test__set_children(self):
+ base = Base()
+ base.set_children([base])
+ assert base.get_children() == [base]
+
+ def test__repr_html_(self):
+ base = Base()
+ assert (
+ base._repr_html_()
+ == "Class: Base
- Hash: a00a595206d7eefcf0e87acf6e2e22ee
- Parent: None
- Children: None
"
+ )
+
+ def test_save_to_html(self):
+ html_path = Path("./tests/cache/test_html.html")
+ Path("./tests/cache/").mkdir(parents=True, exist_ok=True)
+ base = Base()
+ base.save_to_html(html_path)
+ assert Path.exists(html_path)
+ remove_cache_files()
+
+
+class TestBlock:
+ def test_block_init(self):
+ block = Block()
+ assert block is not None
+
+ def test_block_set_hash(self):
+ block = Block()
+ block.set_hash("")
+ hash1 = block.get_hash()
+ assert hash1 != ""
+ block.set_hash(hash1)
+ hash2 = block.get_hash()
+ assert hash2 != ""
+ assert hash1 != hash2
+
+ def test_block_get_hash(self):
+ block = Block()
+ block.set_hash("")
+ hash1 = block.get_hash()
+ assert hash1 != ""
+
+ def test__repr_html_(self):
+ block_instance = Block()
+
+ html_representation = block_instance._repr_html_()
+
+ assert html_representation is not None
+
+
+class TestSequentialSystem:
+ def test_system_init(self):
+ system = SequentialSystem()
+ assert system is not None
+
+ def test_system_hash_no_steps(self):
+ system = SequentialSystem()
+ assert system.get_hash() == ""
+
+ def test_system_hash_with_1_step(self):
+ block1 = Block()
+
+ system = SequentialSystem([block1])
+ assert system.get_hash() != ""
+ assert block1.get_hash() == system.get_hash()
+
+ def test_system_hash_with_2_steps(self):
+ block1 = Block()
+ block2 = Block()
+
+ system = SequentialSystem([block1, block2])
+ assert system.get_hash() != block1.get_hash()
+ assert (
+ system.get_hash() == block2.get_hash() != ""
+ )
+
+ def test_system_hash_with_3_steps(self):
+ block1 = Block()
+ block2 = Block()
+ block3 = Block()
+
+ system = SequentialSystem([block1, block2, block3])
+ assert system.get_hash() != block1.get_hash()
+ assert system.get_hash() != block2.get_hash()
+ assert block1.get_hash() != block2.get_hash()
+ assert (
+ system.get_hash() == block3.get_hash() != ""
+ )
+
+ def test__repr_html_(self):
+ block_instance = Block()
+ system_instance = SequentialSystem([block_instance, block_instance])
+ html_representation = system_instance._repr_html_()
+
+ assert html_representation is not None
+
+
+class TestParallelSystem:
+ def test_parallel_system_init(self):
+ parallel_system = ParallelSystem()
+ assert parallel_system is not None
+
+ def test_parallel_system_hash_no_steps(self):
+ system = ParallelSystem()
+ assert system.get_hash() == ""
+
+ def test_parallel_system_hash_with_1_step(self):
+ block1 = Block()
+
+ system = ParallelSystem([block1])
+ assert system.get_hash() != ""
+ assert block1.get_hash() == system.get_hash()
+
+ def test_parallel_system_hash_with_2_steps(self):
+ block1 = Block()
+ block2 = Block()
+
+ system = ParallelSystem([block1, block2])
+ assert system.get_hash() != block1.get_hash()
+ assert block1.get_hash() == block2.get_hash()
+ assert system.get_hash() != block2.get_hash()
+ assert system.get_hash() != ""
+
+ def test_parallel_system_hash_with_3_steps(self):
+ block1 = Block()
+ block2 = Block()
+ block3 = Block()
+
+ system = ParallelSystem([block1, block2, block3])
+ assert system.get_hash() != block1.get_hash()
+ assert system.get_hash() != block2.get_hash()
+ assert system.get_hash() != block3.get_hash()
+ assert block1.get_hash() == block2.get_hash() == block3.get_hash()
+ assert system.get_hash() != ""
+
+ def test_parallel_system__repr_html_(self):
+ block_instance = Block()
+ system_instance = ParallelSystem([block_instance, block_instance])
+ html_representation = system_instance._repr_html_()
+
+ assert html_representation is not None
diff --git a/tests/pipeline/test_training.py b/tests/pipeline/test_training.py
new file mode 100644
index 0000000..0b7e886
--- /dev/null
+++ b/tests/pipeline/test_training.py
@@ -0,0 +1,614 @@
+import pytest
+import warnings
+from epochlib.pipeline import Trainer, TrainingSystem, ParallelTrainingSystem, Pipeline
+from epochlib.pipeline import Transformer, TransformingSystem
+import numpy as np
+
+
+class TestTrainer:
+ def test_trainer_abstract_train(self):
+ trainer = Trainer()
+ with pytest.raises(NotImplementedError):
+ trainer.train([1, 2, 3], [1, 2, 3])
+
+ def test_trainer_abstract_predict(self):
+ trainer = Trainer()
+ with pytest.raises(NotImplementedError):
+ trainer.predict([1, 2, 3])
+
+ def test_trainer_train(self):
+ class trainerInstance(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ trainer = trainerInstance()
+ assert trainer.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_trainer_predict(self):
+ class trainerInstance(Trainer):
+ def predict(self, x):
+ return x
+
+ trainer = trainerInstance()
+ assert trainer.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_trainer_hash(self):
+ trainer = Trainer()
+ assert trainer.get_hash() != ""
+
+
+class TestTrainingSystem:
+ def test_training_system_init(self):
+ training_system = TrainingSystem()
+ assert training_system is not None
+
+ def test_training_system_init_with_steps(self):
+ class SubTrainer(Trainer):
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem(steps=[block1])
+ assert training_system is not None
+
+ def test_training_system_wrong_step(self):
+ class SubTrainer:
+ def predict(self, x):
+ return x
+
+ with pytest.raises(TypeError):
+ TrainingSystem(steps=[SubTrainer()])
+
+ def test_training_system_steps_changed_predict(self):
+ class SubTrainer:
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem()
+ training_system.steps = [block1]
+ with pytest.raises(TypeError):
+ training_system.predict([1, 2, 3])
+
+ def test_training_system_predict(self):
+ class SubTrainer(Trainer):
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem(steps=[block1])
+ assert training_system.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_trainsys_predict_with_trainer_and_trainsys(self):
+ class SubTrainer(Trainer):
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ block2 = SubTrainer()
+ block3 = TrainingSystem(steps=[block1, block2])
+ assert block2.get_parent() == block3
+ assert block1 in block3.get_children()
+ training_system = TrainingSystem(steps=[block1, block2, block3])
+ assert training_system.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_training_system_train(self):
+ class SubTrainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem(steps=[block1])
+ assert training_system.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_traiinsys_train_with_trainer_and_trainsys(self):
+ class SubTrainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ block1 = SubTrainer()
+ block2 = SubTrainer()
+ block3 = TrainingSystem(steps=[block1, block2])
+ training_system = TrainingSystem(steps=[block1, block2, block3])
+ assert training_system.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_training_system_steps_changed_train(self):
+ class SubTrainer:
+ def train(self, x, y):
+ return x, y
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem()
+ training_system.steps = [block1]
+ with pytest.raises(TypeError):
+ training_system.train([1, 2, 3], [1, 2, 3])
+
+ def test_training_system_empty_hash(self):
+ training_system = TrainingSystem()
+ assert training_system.get_hash() == ""
+
+ def test_training_system_wrong_kwargs(self):
+ class Block1(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ class Block2(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TrainingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "block2": {}}
+ with pytest.warns(
+ UserWarning,
+ match="The following steps do not exist but were given in the kwargs:",
+ ):
+ system.train([1, 2, 3], [1, 2, 3], **kwargs)
+ system.predict([1, 2, 3], **kwargs)
+
+ def test_training_system_right_kwargs(self):
+ class Block1(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ class Block2(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TrainingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "Block2": {}}
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ system.train([1, 2, 3], [1, 2, 3], **kwargs)
+ system.predict([1, 2, 3], **kwargs)
+ assert not caught_warnings
+
+
+class TestParallelTrainingSystem:
+ def test_PTrainSys_init(self):
+ system = ParallelTrainingSystem()
+
+ assert system is not None
+
+ def test_PTrainSys_init_trainers(self):
+ t1 = Trainer()
+ t2 = TrainingSystem()
+
+ system = ParallelTrainingSystem(steps=[t1, t2])
+
+ assert system is not None
+
+ def test_PTrainSys_init_wrong_trainers(self):
+ class WrongTrainer:
+ """Wrong trainer"""
+
+ t1 = WrongTrainer()
+
+ with pytest.raises(TypeError):
+ ParallelTrainingSystem(steps=[t1])
+
+ def test_PTrainSys_train(self):
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+
+ return data1 + data2
+
+ t1 = trainer()
+
+ system = pts(steps=[t1])
+
+ assert system is not None
+ assert system.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_PTrainSys_trainers(self):
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = trainer()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ assert system.train([1, 2, 3], [1, 2, 3]) == (
+ [1, 2, 3, 1, 2, 3],
+ [1, 2, 3, 1, 2, 3],
+ )
+
+ def test_PTrainSys_trainers_with_weights(self):
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class trainer2(Trainer):
+ def train(self, x, y):
+ return x * 3, y
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2 * weight
+ return data1 + data2 * weight
+
+ t1 = trainer()
+ t2 = trainer2()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ test = np.array([1, 2, 3])
+ preds, labels = system.train(test, test)
+ assert np.array_equal(preds, test * 2)
+ assert np.array_equal(labels, test)
+
+ def test_PTrainSys_predict(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+
+ system = pts(steps=[t1])
+
+ assert system is not None
+ assert system.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_PTrainSys_predict_with_trainsys(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = TrainingSystem(steps=[t1])
+
+ system = pts(steps=[t2, t1])
+
+ assert system is not None
+ assert system.predict([1, 2, 3]) == [1, 2, 3, 1, 2, 3]
+
+ def test_PTrainSys_predict_with_trainer_and_trainsys(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = trainer()
+ t3 = TrainingSystem(steps=[t1, t2])
+
+ system = pts(steps=[t1, t2, t3])
+
+ assert system is not None
+ assert t3.predict([1, 2, 3]) == [1, 2, 3]
+ assert system.predict([1, 2, 3]) == [1, 2, 3, 1, 2, 3, 1, 2, 3]
+
+ def test_PTrainSys_predictors(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = trainer()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ assert system.predict([1, 2, 3]) == [1, 2, 3, 1, 2, 3]
+
+ def test_PTrainSys_concat_labels_throws_error(self):
+ system = ParallelTrainingSystem()
+
+ with pytest.raises(NotImplementedError):
+ system.concat_labels([1, 2, 3], [4, 5, 6])
+
+ def test_PTrainSys_step_1_changed(self):
+ system = ParallelTrainingSystem()
+
+ t1 = Transformer()
+ system.steps = [t1]
+
+ with pytest.raises(TypeError):
+ system.train([1, 2, 3], [1, 2, 3])
+
+ with pytest.raises(TypeError):
+ system.predict([1, 2, 3])
+
+ def test_PTrainSys_step_2_changed(self):
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+
+ return data1 + data2
+
+ system = pts()
+
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ def predict(self, x):
+ return x
+
+ t1 = trainer()
+ t2 = Transformer()
+ system.steps = [t1, t2]
+
+ with pytest.raises(TypeError):
+ system.train([1, 2, 3], [1, 2, 3])
+
+ with pytest.raises(TypeError):
+ system.predict([1, 2, 3])
+
+ def test_train_parallel_hashes(self):
+ class SubTrainer1(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class SubTrainer2(Trainer):
+ def train(self, x, y):
+ return x * 2, y
+
+ block1 = SubTrainer1()
+ block2 = SubTrainer2()
+
+ system1 = ParallelTrainingSystem(steps=[block1, block2])
+ system1_copy = ParallelTrainingSystem(steps=[block1, block2])
+ system2 = ParallelTrainingSystem(steps=[block2, block1])
+ system2_copy = ParallelTrainingSystem(steps=[block2, block1])
+
+ assert system1.get_hash() == system2.get_hash()
+ assert system1.get_hash() == system1_copy.get_hash()
+ assert system2.get_hash() == system2_copy.get_hash()
+
+
+class TestPipeline:
+ def test_pipeline_init(self):
+ pipeline = Pipeline()
+ assert pipeline is not None
+
+ def test_pipeline_init_with_systems(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ label_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ label_sys=label_system,
+ )
+ assert pipeline is not None
+
+ def test_pipeline_train(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ label_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ label_sys=label_system,
+ )
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_no_y_system(self):
+ x_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_no_x_system(self):
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_no_train_system(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ post_system = TransformingSystem()
+ post_label_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=None,
+ pred_sys=post_system,
+ label_sys=post_label_system,
+ )
+ assert pipeline.train([1, 2], [1, 2]) == ([1, 2], [1, 2])
+
+ def test_pipeline_train_no_refining_system(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ pipeline = Pipeline(x_sys=x_system, y_sys=y_system, train_sys=training_system)
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_1_x_transform_block(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem(steps=[transform1])
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ result = pipeline.train(np.array([1, 2, 3]), [1, 2, 3])
+ assert np.array_equal(result[0], np.array([2, 4, 6])) and np.array_equal(
+ result[1], np.array([1, 2, 3])
+ )
+
+ def test_pipeline_predict(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_pipeline_predict_no_y_system(self):
+ x_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_pipeline_predict_no_systems(self):
+ pipeline = Pipeline()
+ assert pipeline.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_pipeline_get_hash_no_change(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ predicting_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=predicting_system,
+ )
+ assert x_system.get_hash() == ""
+ # assert y_system.get_hash() == ""
+ # assert training_system.get_hash() == ""
+ # assert predicting_system.get_hash() == ""
+ # assert pipeline.get_hash() == ""
+
+ def test_pipeline_get_hash_with_change(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem(steps=[transform1])
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert x_system.get_hash() != y_system.get_hash()
+ assert pipeline.get_hash() != ""
+
+ def test_pipeline_predict_system_hash(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem(steps=[transform1])
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert prediction_system.get_hash() != x_system.get_hash()
+ assert pipeline.get_hash() != ""
+
+ def test_pipeline_pre_post_hash(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem(steps=[transform1])
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem(steps=[transform1])
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert x_system.get_hash() == prediction_system.get_hash()
+ assert pipeline.get_hash() != ""
diff --git a/tests/pipeline/test_transforming.py b/tests/pipeline/test_transforming.py
new file mode 100644
index 0000000..394d900
--- /dev/null
+++ b/tests/pipeline/test_transforming.py
@@ -0,0 +1,321 @@
+import warnings
+import numpy as np
+import pytest
+
+from epochlib.pipeline import Trainer
+from epochlib.pipeline import (
+ Transformer,
+ TransformingSystem,
+ ParallelTransformingSystem,
+)
+
+
+class TestTransformer:
+ def test_transformer_abstract(self):
+ transformer = Transformer()
+
+ with pytest.raises(NotImplementedError):
+ transformer.transform([1, 2, 3])
+
+ def test_transformer_transform(self):
+ class transformerInstance(Transformer):
+ def transform(self, data):
+ return data
+
+ transformer = transformerInstance()
+
+ assert transformer.transform([1, 2, 3]) == [1, 2, 3]
+
+ def test_transformer_hash(self):
+ transformer = Transformer()
+ assert transformer.get_hash() == "1cbcc4f2d0921b050d9b719d2beb6529"
+
+
+class TestTransformingSystem:
+ def test_transforming_system_init(self):
+ transforming_system = TransformingSystem()
+ assert transforming_system is not None
+
+ def test_transforming_system_init_with_steps(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ assert transforming_system is not None
+
+ def test_transforming_system_wrong_step(self):
+ class SubTransformer:
+ def transform(self, x):
+ return x
+
+ with pytest.raises(TypeError):
+ TransformingSystem(steps=[SubTransformer()])
+
+ def test_transforming_system_steps_changed(self):
+ class SubTransformer:
+ def transform(self, x):
+ return x
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem()
+ transforming_system.steps = [block1]
+ with pytest.raises(TypeError):
+ transforming_system.transform([1, 2, 3])
+
+ def test_transforming_system_transform_1_block(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ assert transforming_system.transform([1, 2, 3]) == [1, 2, 3]
+
+ def test_transforming_system_transform_1_block_with_args(self):
+ class SubTransformer(Transformer):
+ def transform(self, data):
+ return data
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ assert transforming_system.transform([1, 2, 3], **{"SubTransformer": {}}) == [
+ 1,
+ 2,
+ 3,
+ ]
+
+ def test_transforming_system_transform_2_blocks(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ block1 = SubTransformer()
+ block2 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1, block2])
+ result = transforming_system.transform(np.array([1, 2, 3]))
+ assert np.array_equal(result, np.array([4, 8, 12]))
+
+ def test_transformsys_with_transformsys(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ block1 = SubTransformer()
+ block2 = TransformingSystem(steps=[block1])
+ transforming_system = TransformingSystem(steps=[block2])
+ result = transforming_system.transform(np.array([1, 2, 3]))
+ assert np.array_equal(result, np.array([2, 4, 6]))
+
+ def test_transforming_system_transform_with_args(self):
+ class SubTransformer(Transformer):
+ def transform(self, data, multiplier=2):
+ return data * multiplier
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ result = transforming_system.transform(
+ np.array([1, 2, 3]), **{"SubTransformer": {"multiplier": 2}}
+ )
+ assert np.array_equal(result, np.array([2, 4, 6]))
+
+ def test_transforming_system_transform_with_args_2_blocks(self):
+ class SubTransformer(Transformer):
+ def transform(self, data, multiplier=2):
+ return data * multiplier
+
+ block1 = SubTransformer()
+ block2 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1, block2])
+ result = transforming_system.transform(
+ np.array([1, 2, 3]), **{"SubTransformer": {"multiplier": 2}}
+ )
+ assert np.array_equal(result, np.array([4, 8, 12]))
+
+ def test_transforming_system_transform_with_recursive_args(self):
+ class SubTransformer(Transformer):
+ def transform(self, data, multiplier=2):
+ return data * multiplier
+
+ block1 = SubTransformer()
+ block2 = SubTransformer()
+ block3 = TransformingSystem(steps=[block2])
+ block4 = TransformingSystem(steps=[block3])
+ transforming_system = TransformingSystem(steps=[block1, block4])
+ assert np.array_equal(
+ transforming_system.transform(
+ np.array([1, 2, 3]), **{"SubTransformer": {"multiplier": 2}}
+ ),
+ np.array([4, 8, 12]),
+ )
+ assert np.array_equal(
+ transforming_system.transform(
+ np.array([1, 2, 3]),
+ **{
+ "TransformingSystem": {
+ "TransformingSystem": {"SubTransformer": {"multiplier": 3}}
+ }
+ },
+ ),
+ np.array([6, 12, 18]),
+ )
+
+ def test_transforming_system_empty_hash(self):
+ transforming_system = TransformingSystem()
+ assert transforming_system.get_hash() == ""
+
+ def test_transforming_system_wrong_kwargs(self):
+ class Block1(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ class Block2(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TransformingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "block2": {}}
+ with pytest.warns(
+ UserWarning,
+ match="The following steps do not exist but were given in the kwargs:",
+ ):
+ system.transform([1, 2, 3], **kwargs)
+
+ def test_transforming_system_right_kwargs(self):
+ class Block1(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ class Block2(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TransformingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "Block2": {}}
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ system.transform([1, 2, 3], **kwargs)
+
+ assert not caught_warnings
+
+
+class TestParallelTransformingSystem:
+ def test_parallel_transforming_system(self):
+ # Create an instance of the system
+ system = ParallelTransformingSystem()
+
+ # Assert the system is an instance of ParallelTransformingSystem
+ assert isinstance(system, ParallelTransformingSystem)
+ assert system is not None
+
+ def test_parallel_transforming_system_wrong_step(self):
+ class SubTransformer:
+ def transform(self, x):
+ return x
+
+ with pytest.raises(TypeError):
+ ParallelTransformingSystem(steps=[SubTransformer()])
+
+ def test_parallel_transforming_system_transformers(self):
+ transformer1 = Transformer()
+ transformer2 = TransformingSystem()
+
+ system = ParallelTransformingSystem(steps=[transformer1, transformer2])
+ assert system is not None
+
+ def test_parallel_transforming_system_transform(self):
+ class transformer(Transformer):
+ def transform(self, data):
+ return data
+
+ class pts(ParallelTransformingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = transformer()
+
+ system = pts(steps=[t1])
+
+ assert system is not None
+ assert system.transform([1, 2, 3]) == [1, 2, 3]
+
+ def test_pts_transformers_transform(self):
+ class transformer(Transformer):
+ def transform(self, data):
+ return data
+
+ class pts(ParallelTransformingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = transformer()
+ t2 = transformer()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ assert system.transform([1, 2, 3]) == [1, 2, 3, 1, 2, 3]
+
+ def test_parallel_transforming_system_concat_throws_error(self):
+ system = ParallelTransformingSystem()
+
+ with pytest.raises(NotImplementedError):
+ system.concat([1, 2, 3], [4, 5, 6])
+
+ def test_pts_step_1_changed(self):
+ system = ParallelTransformingSystem()
+
+ t1 = Trainer()
+ system.steps = [t1]
+
+ with pytest.raises(TypeError):
+ system.transform([1, 2, 3])
+
+ def test_pts_step_2_changed(self):
+ class pts(ParallelTransformingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ system = pts()
+
+ class transformer(Transformer):
+ def transform(self, data):
+ return data
+
+ t1 = transformer()
+ t2 = Trainer()
+ system.steps = [t1, t2]
+
+ with pytest.raises(TypeError):
+ system.transform([1, 2, 3])
+
+ def test_transform_parallel_hashes(self):
+ class SubTransformer1(Transformer):
+ def transform(self, x):
+ return x
+
+ class SubTransformer2(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ block1 = SubTransformer1()
+ block2 = SubTransformer2()
+
+ system1 = ParallelTransformingSystem(steps=[block1, block2])
+ system1_copy = ParallelTransformingSystem(steps=[block1, block2])
+ system2 = ParallelTransformingSystem(steps=[block2, block1])
+ system2_copy = ParallelTransformingSystem(steps=[block2, block1])
+
+ assert system1.get_hash() == system2.get_hash()
+ assert system1.get_hash() == system1_copy.get_hash()
+ assert system2.get_hash() == system2_copy.get_hash()
diff --git a/tests/pipeline/util.py b/tests/pipeline/util.py
new file mode 100644
index 0000000..6eb0680
--- /dev/null
+++ b/tests/pipeline/util.py
@@ -0,0 +1,11 @@
+import glob
+import os
+
+
+def remove_cache_files():
+ files = glob.glob("tests/cache/*")
+ for f in files:
+ # If f is readme.md, skip it
+ if "README.md" in f:
+ continue
+ os.remove(f)
diff --git a/tests/training/test_training.py b/tests/training/test_training.py
index ce6df00..d8ba215 100644
--- a/tests/training/test_training.py
+++ b/tests/training/test_training.py
@@ -1,6 +1,6 @@
import numpy as np
import pytest
-from agogos.training import Trainer
+from epochlib.pipeline import Trainer
from epochlib.training import TrainingPipeline
from epochlib.training import TrainingBlock
diff --git a/tests/transformation/test_transformation.py b/tests/transformation/test_transformation.py
index df76240..cb2d14b 100644
--- a/tests/transformation/test_transformation.py
+++ b/tests/transformation/test_transformation.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest
-from agogos.transforming import Transformer
+from epochlib.pipeline import Transformer
from epochlib.transformation import TransformationPipeline
from epochlib.transformation import TransformationBlock