30 changes: 30 additions & 0 deletions epochlib/caching/README.md
@@ -0,0 +1,30 @@
# Caching

Caching is an important aspect of machine learning competitions: it speeds up training and development, and it makes it possible to recover from failures, among many other use cases. This folder defines an interface that specifies the methods a cacher must implement so that it can be injected into other objects. A cacher should return the same data type when reading data as the one it received when storing it. The name of a cacher therefore specifies two things: the storage format (e.g. parquet) and the data type it reads and writes (e.g. a numpy array).

Initially this was implemented via inheritance, since that seemed the easiest way to give classes access to the caching methods. However, inheritance hides a lot of information and is not very usable. Instead, a cacher can be injected as a dataclass field that is excluded from the hash, so it can be used without affecting the hash of the object that holds it, as sketched below.
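
A minimal sketch of this injection pattern follows. The `Block` class and its `run` method are hypothetical, and `NumpyArrayToNpyCacher` is assumed to take the same `storage_path` argument as the cachers in this PR; only `CacherInterface` and its three methods are taken as given:

```python
from dataclasses import dataclass, field

import numpy as np

from epochlib.caching import CacherInterface, NumpyArrayToNpyCacher


@dataclass(unsafe_hash=True)
class Block:
    """Hypothetical pipeline block that caches the result of its run method."""

    seed: int = 42
    # Excluded from hashing and comparison, so the choice of cacher does not change the block's hash.
    cacher: CacherInterface | None = field(default=None, hash=False, compare=False)

    def run(self, name: str) -> np.ndarray:
        # Reuse the cached result when it exists, otherwise compute and store it.
        if self.cacher is not None and self.cacher.cache_exists(name):
            return self.cacher.load_cache(name)
        result = np.random.default_rng(self.seed).random((10, 3))
        if self.cacher is not None:
            self.cacher.store_cache(name, result)
        return result


block = Block(cacher=NumpyArrayToNpyCacher(storage_path="tmp/cache"))
features = block.run("features")
```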

## Available cachers

A list of the available cachers, grouped by storage format, is provided below; a short usage sketch follows the list.
.npy:
- Numpy Array
- Dask Array

.parquet:
- Pandas Dataframe
- Dask Dataframe
- Numpy Array
- Dask Array
- Polars Dataframe

.csv:
- Pandas Dataframe
- Dask Dataframe
- Polars Dataframe

.npy_stack:
- Dask Array

.pkl:
- Any Object
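
For direct use, a short sketch with the `AnyToPklCacher` from this PR (the `tmp/cache` path and the example data are arbitrary):

```python
from epochlib.caching import AnyToPklCacher

cacher = AnyToPklCacher(storage_path="tmp/cache")

data = {"fold": 0, "scores": [0.81, 0.84, 0.83]}
cacher.store_cache("cv_results", data)  # writes tmp/cache/cv_results.pkl

if cacher.cache_exists("cv_results"):
    restored = cacher.load_cache("cv_results")  # returns the same type of object that was stored
```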
31 changes: 30 additions & 1 deletion epochlib/caching/__init__.py
@@ -1,5 +1,34 @@
"""Caching module for epochlib."""

from .cacher import CacheArgs, Cacher
from .any_to_pkl_cacher import AnyToPklCacher
from .cacher_interface import CacherInterface
from .dask_array_to_npy_cacher import DaskArrayToNpyCacher
from .dask_array_to_npy_stack_cacher import DaskArrayToNpyStackCacher
from .dask_array_to_parquet_cacher import DaskArrayToParquetCacher
from .dask_dataframe_to_csv_cacher import DaskDataFrameToCSVCacher
from .dask_dataframe_to_parquet_cacher import DaskDataFrameToParquetCacher
from .numpy_array_to_npy_cacher import NumpyArrayToNpyCacher
from .numpy_array_to_parquet_cacher import NumpyArrayToParquetCacher
from .pandas_dataframe_to_csv_cacher import PandasDataFrameToCSVCacher
from .pandas_dataframe_to_parquet_cacher import PandasDataFrameToParquetCacher
from .polars_dataframe_to_csv_cacher import PolarsDataFrameToCSVCacher
from .polars_dataframe_to_parquet_cacher import PolarsDataFrameToParquetCacher

__all__ = ["Cacher", "CacheArgs"]
__all__ = [
"Cacher",
"CacheArgs",
"CacherInterface",
"NumpyArrayToNpyCacher",
"DaskArrayToNpyCacher",
"AnyToPklCacher",
"DaskArrayToNpyStackCacher",
"DaskArrayToParquetCacher",
"DaskDataFrameToCSVCacher",
"DaskDataFrameToParquetCacher",
"NumpyArrayToParquetCacher",
"PandasDataFrameToCSVCacher",
"PandasDataFrameToParquetCacher",
"PolarsDataFrameToCSVCacher",
"PolarsDataFrameToParquetCacher",
]
50 changes: 50 additions & 0 deletions epochlib/caching/any_to_pkl_cacher.py
@@ -0,0 +1,50 @@
"""This module contains the AnyToPklCacher class."""

import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from .cacher_interface import CacherInterface


@dataclass
class AnyToPklCacher(CacherInterface):
"""The any to .pkl cacher.

:param storage_path: The path to store the cache files.
:param read_args: The arguments to read the cache, pickle.load() extra args.
:param store_args: The arguments to store the cache, pickle.dump() extra args.
"""

storage_path: str
read_args: dict[str, Any] = field(default_factory=dict)
store_args: dict[str, Any] = field(default_factory=dict)

def cache_exists(self, name: str) -> bool:
"""Check if a cache exists.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)

return os.path.exists(storage_path / f"{name}.pkl")

def load_cache(self, name: str) -> Any:
"""Load a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
with open(Path(self.storage_path) / f"{name}.pkl", "rb") as file:
return pickle.load(file, **self.read_args) # noqa: S301

def store_cache(self, name: str, data: Any) -> None:
"""Store a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)
storage_path.mkdir(parents=True, exist_ok=True)
with open(storage_path / f"{name}.pkl", "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL, **self.store_args)
40 changes: 40 additions & 0 deletions epochlib/caching/cacher_interface.py
@@ -0,0 +1,40 @@
"""The cache interface defines the methods that a cacher must implement."""

from typing import Any


class CacherInterface:
"""The cache interface defines the methods that a cacher must implement.

Methods
-------
.. code-block:: python
def cache_exists(name: str) -> bool

def load_cache(name: str) -> Any

def store_cache(name: str, data: Any) -> None
"""

def cache_exists(self, name: str) -> bool:
"""Check if a cache exists.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
raise NotImplementedError(f"cache_exists from CacheInterface not implemented in {self.__class__.__name__}.")

def load_cache(self, name: str) -> Any:
"""Load a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
raise NotImplementedError(f"load_cache from CacheInterface not implemented in {self.__class__.__name__}.")

def store_cache(self, name: str, data: Any) -> None:
"""Store a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
:param data: The data to store in the cache.
"""
raise NotImplementedError(f"store_cache from CacherInterface not implemented in {self.__class__.__name__}.")
65 changes: 65 additions & 0 deletions epochlib/caching/dask_array_to_npy_cacher.py
@@ -0,0 +1,65 @@
"""This module contains the DaskArrayToNpyCacher class."""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np

from .cacher_interface import CacherInterface

try:
import dask.array as da
except ImportError:
pass  # Dask is optional; with the future import above, the da.Array annotations below are not evaluated at import time


@dataclass
class DaskArrayToNpyCacher(CacherInterface):
"""The dask array to .npy cacher.

:param storage_path: The path to store the cache files.
:param read_args: The arguments to read the cache, da.from_array() extra args.
:param store_args: The arguments to store the cache, np.save() extra args.

Methods
-------
.. code-block:: python
def cache_exists(name: str) -> bool

def load_cache(name: str) -> Any

def store_cache(name: str, data: Any) -> None
"""

storage_path: str
read_args: dict[str, Any] = field(default_factory=dict)
store_args: dict[str, Any] = field(default_factory=dict)

def cache_exists(self, name: str) -> bool:
"""Check if a cache exists.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)

return os.path.exists(storage_path / f"{name}.npy")

def load_cache(self, name: str) -> da.Array:
"""Load a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
return da.from_array(np.load(Path(self.storage_path) / f"{name}.npy"), **self.read_args)

def store_cache(self, name: str, data: da.Array) -> None:
"""Store a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
:param data: The data to store in the cache.
"""
storage_path = Path(self.storage_path)
storage_path.mkdir(parents=True, exist_ok=True)
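# Materialise the dask array into a concrete numpy array before saving it with np.save.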
np.save(storage_path / f"{name}.npy", data.compute(), **self.store_args)
49 changes: 49 additions & 0 deletions epochlib/caching/dask_array_to_npy_stack_cacher.py
@@ -0,0 +1,49 @@
"""This module contains the DaskArrayToNpyStackCacher class."""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import dask.array as da

from .cacher_interface import CacherInterface


@dataclass
class DaskArrayToNpyStackCacher(CacherInterface):
"""The dask array to .npy_stack cacher.

:param storage_path: The path to store the cache files.
:param read_args: The arguments to read the cache, da.from_npy_stack() extra args.
:param store_args: The arguments to store the cache, da.to_npy_stack() extra args.
"""

storage_path: str
read_args: dict[str, Any] = field(default_factory=dict)
store_args: dict[str, Any] = field(default_factory=dict)

def cache_exists(self, name: str) -> bool:
"""Check if a cache exists.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)

return os.path.exists(storage_path / name)

def load_cache(self, name: str) -> da.Array:
"""Load a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
return da.from_npy_stack(Path(self.storage_path) / name, **self.read_args)

def store_cache(self, name: str, data: da.Array) -> None:
"""Store a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)
storage_path.mkdir(parents=True, exist_ok=True)
da.to_npy_stack(storage_path / name, data, **self.store_args)
54 changes: 54 additions & 0 deletions epochlib/caching/dask_array_to_parquet_cacher.py
@@ -0,0 +1,54 @@
"""This module contains the DaskArrayToParquetCacher class."""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import dask.array as da
import dask.dataframe as dd

from .cacher_interface import CacherInterface


@dataclass
class DaskArrayToParquetCacher(CacherInterface):
"""The dask array to .parquet cacher.

:param storage_path: The path to store the cache files.
:param read_args: The arguments to read the cache, dd.read_parquet() extra args.
:param store_args: The arguments to store the cache, dd.to_parquet() extra args.
"""

storage_path: str
read_args: dict[str, Any] = field(default_factory=dict)
store_args: dict[str, Any] = field(default_factory=dict)

def cache_exists(self, name: str) -> bool:
"""Check if a cache exists.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)

return os.path.exists(storage_path / f"{name}.parquet")

def load_cache(self, name: str) -> da.Array:
"""Load a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
return dd.read_parquet(Path(self.storage_path) / f"{name}.parquet", **self.read_args).to_dask_array()

def store_cache(self, name: str, data: da.Array) -> None:
"""Store a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)
storage_path.mkdir(parents=True, exist_ok=True)
new_dd = dd.from_dask_array(data)
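# Parquet requires string column names, so cast the integer column labels produced by from_dask_array to str.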
new_dd = new_dd.rename(
columns={col: str(col) for col in new_dd.columns},
)
new_dd.to_parquet(storage_path / f"{name}.parquet", **self.store_args)
50 changes: 50 additions & 0 deletions epochlib/caching/dask_dataframe_to_csv_cacher.py
@@ -0,0 +1,50 @@
"""This module contains the DaskDataFrameToCSVCacher class."""

import glob
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import dask.dataframe as dd

from .cacher_interface import CacherInterface


@dataclass
class DaskDataFrameToCSVCacher(CacherInterface):
"""The dask dataframe to .csv cacher.

:param storage_path: The path to store the cache files.
:param read_args: The arguments to read the cache, dd.read_csv() extra args.
:param store_args: The arguments to store the cache, dd.to_csv() extra args.
"""

storage_path: str
read_args: dict[str, Any] = field(default_factory=dict)
store_args: dict[str, Any] = field(default_factory=dict)

def cache_exists(self, name: str) -> bool:
"""Check if a cache exists.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)

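# The cache may be a single .csv file or a directory of per-partition *.part files (the default layout when dask writes a path without a "*").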
return os.path.exists(storage_path / f"{name}.csv") or glob.glob(str(storage_path / name / "*.part")) != []

def load_cache(self, name: str) -> dd.DataFrame:
"""Load a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
return dd.read_csv(Path(self.storage_path) / name / "*.part", **self.read_args)

def store_cache(self, name: str, data: dd.DataFrame) -> None:
"""Store a cache.

:param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_].
"""
storage_path = Path(self.storage_path)
storage_path.mkdir(parents=True, exist_ok=True)
data.to_csv(storage_path / name, index=False, **self.store_args)