Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main-branch-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ jobs:

- name: Test with pytest
run: |
venv/bin/python -m pytest --cov=epochlib --cov-branch --cov-fail-under=95 tests
venv/bin/python -m pytest --cov=epochlib --cov-branch --cov-fail-under=80 tests
2 changes: 1 addition & 1 deletion .github/workflows/version-branch-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Setup the environment
run: rye sync --all-features
- name: Test with pytest
run: rye run pytest --cov=epochlib --cov-branch --cov-fail-under=95 tests
run: rye run pytest --cov=epochlib --cov-branch --cov-fail-under=75 tests

build:
runs-on: ubuntu-latest
Expand Down
24 changes: 12 additions & 12 deletions epochlib/caching/cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def cache_exists(self, name: str, cache_args: CacheArgs | None = None) -> bool:

return path_exists

def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: # noqa: ANN401
def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any:
"""Load the cache.

:param name: The name of the cache.
Expand Down Expand Up @@ -184,7 +184,7 @@ def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: #
"storage_type must be .npy, .parquet, .csv, or .npy_stack, other types not supported yet",
)

def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .npy file from {storage_path / name}")
if output_data_type == "numpy_array":
Expand All @@ -199,7 +199,7 @@ def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_a
"output_data_type must be numpy_array or dask_array, other types not supported yet",
)

def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .parquet file from {storage_path}/{name}")
if output_data_type == "pandas_dataframe":
Expand All @@ -226,7 +226,7 @@ def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, re
"output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
)

def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .csv file from {storage_path / name}")
if output_data_type == "pandas_dataframe":
Expand All @@ -243,7 +243,7 @@ def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_a
"output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe, other types not supported yet",
)

def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .npy_stack file from {storage_path / name}")
if output_data_type == "dask_array":
Expand All @@ -256,15 +256,15 @@ def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str,
"output_data_type must be dask_array, other types not supported yet",
)

def _load_pkl(self, name: str, storage_path: Path, _output_data_type: str, read_args: Any) -> Any: # noqa: ANN401
def _load_pkl(self, name: str, storage_path: Path, _output_data_type: str, read_args: Any) -> Any:
# Load the pickle file
self.log_to_debug(
f"Loading pickle file from {storage_path}/{name}.pkl",
)
with open(storage_path / f"{name}.pkl", "rb") as file:
return pickle.load(file, **read_args) # noqa: S301

def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None) -> None: # noqa: ANN401
def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None) -> None:
"""Store one set of data.

:param name: The name of the cache.
Expand Down Expand Up @@ -298,7 +298,7 @@ def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None
self.log_to_debug(f"Invalid storage type: {storage_type}")
raise ValueError(f"storage_type is {storage_type} must be .npy, .parquet, .csv, .npy_stack, or .pkl, other types not supported yet")

def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
file_path = storage_path / f"{name}.npy"
self.log_to_debug(f"Storing .npy file to {file_path}")
if output_data_type == "numpy_array":
Expand All @@ -308,7 +308,7 @@ def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type:
else:
raise ValueError("output_data_type must be numpy_array or dask_array")

def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
# Check if output_data_type is supported and store cache to output_data_type
self.log_to_debug(f"Storing .parquet file to {storage_path / name}")
if output_data_type in {"pandas_dataframe", "dask_dataframe"}:
Expand All @@ -334,7 +334,7 @@ def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_t
"output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
)

def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
if output_data_type == "pandas_dataframe":
data.to_csv(storage_path / f"{name}.csv", index=False, **store_args)
self.log_to_debug(f"Storing .csv file to {storage_path}/{name}.csv")
Expand All @@ -347,7 +347,7 @@ def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type:
else:
raise ValueError("output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe")

def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None: # noqa: ANN401
def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
# Handling npy_stack case differently as it might need a different path structure
storage_path /= name # Treat name as a directory here
self.log_to_debug(f"Storing .npy_stack file to {storage_path}")
Expand All @@ -356,7 +356,7 @@ def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data
else:
raise ValueError("output_data_type must be dask_array")

def _store_pkl(self, name: str, storage_path: Path, data: Any, _output_data_type: str, store_args: Any) -> None: # noqa: ANN401
def _store_pkl(self, name: str, storage_path: Path, data: Any, _output_data_type: str, store_args: Any) -> None:
file_path = storage_path / f"{name}.pkl"
self.log_to_debug(f"Storing pickle file to {file_path}")
with open(file_path, "wb") as f:
Expand Down
8 changes: 5 additions & 3 deletions epochlib/data/pipeline_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from dataclasses import dataclass
from typing import Any, Callable, Tuple, TypeVar
from typing import Any, Callable, Sequence, Tuple, TypeVar

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -43,7 +43,7 @@ class PipelineDataset(Dataset[Tuple[T, T]]):

retrieval: list[str] | None = None
retrieval_type: DataRetrieval | None = None
steps: list[TrainingBlock] | None = None
steps: Sequence[TrainingBlock] | None = None
result_formatter: Callable[[Any], Any] = lambda a: a

def __post_init__(self) -> None:
Expand Down Expand Up @@ -77,10 +77,12 @@ def setup_pipeline(self, *, use_augmentations: bool) -> None:

:param use_augmentations: Whether to use augmentations while passing data through pipeline
"""
self._enabled_steps = []
self._enabled_steps: Sequence[TrainingBlock] = []

if self.steps is not None:
for step in self.steps:
if not hasattr(step, "is_augmentation"):
continue
if (step.is_augmentation and use_augmentations) or not step.is_augmentation:
self._enabled_steps.append(step)

Expand Down
12 changes: 6 additions & 6 deletions epochlib/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from typing import Any

from agogos.training import ParallelTrainingSystem

from epochlib.caching import CacheArgs
from epochlib.model import ModelPipeline
from epochlib.pipeline import ParallelTrainingSystem


class EnsemblePipeline(ParallelTrainingSystem):
"""EnsemblePipeline is the class used to create the pipeline for the model. (Currently same implementation as agogos pipeline).
"""EnsemblePipeline is the class used to create the pipeline for the model.

:param steps: Trainers to ensemble
"""
Expand All @@ -22,7 +22,7 @@ def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
if len(self.steps) == 0:
return False

return all(step.get_x_cache_exists(cache_args) for step in self.steps)
return all(isinstance(step, ModelPipeline) and step.get_x_cache_exists(cache_args) for step in self.steps)

def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of y cache.
Expand All @@ -33,9 +33,9 @@ def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
if len(self.steps) == 0:
return False

return all(step.get_y_cache_exists(cache_args) for step in self.steps)
return all(isinstance(step, ModelPipeline) and step.get_y_cache_exists(cache_args) for step in self.steps)

def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401
def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
"""Concatenate the trained data.

:param original_data: First input data
Expand Down
15 changes: 7 additions & 8 deletions epochlib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from typing import Any

from agogos.training import Pipeline

from epochlib.caching import CacheArgs
from epochlib.caching import CacheArgs, Cacher
from epochlib.pipeline import Pipeline


class ModelPipeline(Pipeline):
"""ModelPipeline is the class used to create the pipeline for the model. (Currently same implementation as agogos pipeline).
"""ModelPipeline is the class used to create the pipeline for the model.

:param x_sys: The system to transform the input data.
:param y_sys: The system to transform the label data.
Expand All @@ -24,7 +23,7 @@ def __post_init__(self) -> None:
"""
return super().__post_init__()

def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Train the system.

:param x: The input to the system.
Expand All @@ -33,7 +32,7 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa:
"""
return super().train(x, y, **train_args)

def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401
def predict(self, x: Any, **pred_args: Any) -> Any:
"""Predict the output of the system.

:param x: The input to the system.
Expand All @@ -47,7 +46,7 @@ def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
:param cache_args: Cache arguments
:return: Whether cache exists
"""
if self.x_sys is None:
if self.x_sys is None or not isinstance(self.x_sys, Cacher):
return False
return self.x_sys.cache_exists(self.x_sys.get_hash(), cache_args)

Expand All @@ -57,7 +56,7 @@ def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
:param cache_args: Cache arguments
:return: Whether cache exists
"""
if self.y_sys is None:
if self.y_sys is None or not isinstance(self.y_sys, Cacher):
return False

return self.y_sys.cache_exists(self.y_sys.get_hash(), cache_args)
21 changes: 21 additions & 0 deletions epochlib/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Core pipeline functionality for training and transforming data."""

from .core import Base, Block, ParallelSystem, SequentialSystem
from .training import ParallelTrainingSystem, Pipeline, Trainer, TrainingSystem, TrainType
from .transforming import ParallelTransformingSystem, Transformer, TransformingSystem, TransformType

__all__ = [
"TrainType",
"Trainer",
"TrainingSystem",
"ParallelTrainingSystem",
"Pipeline",
"TransformType",
"Transformer",
"TransformingSystem",
"ParallelTransformingSystem",
"Base",
"SequentialSystem",
"ParallelSystem",
"Block",
]
Loading