diff --git a/epochlib/caching/README.md b/epochlib/caching/README.md new file mode 100644 index 0000000..97745b2 --- /dev/null +++ b/epochlib/caching/README.md @@ -0,0 +1,30 @@ +# Caching + +Caching is an important aspect of machine learning competitions: whether it is speeding up your training or development, or recovering from failures, there are many use cases. In this folder, an interface is defined for the methods a cacher should implement such that it can be injected into other objects. Cachers should output the same data type when reading the data as the one they received when storing. A cacher's name should therefore specify two things: the storage type, like parquet, and the output data type, like a numpy array. +Initially this was implemented with inheritance, as it was thought to be easier to have access within the class. However, this hides away a lot of information and is not very usable. By excluding the field from the hash in a dataclass, it can be used without affecting the hash. 
+ +## Available cachers + +A list of the available cachers is provided below: +.npy: +- Numpy Array +- Dask Array + +.parquet: +- Pandas Dataframe +- Dask Dataframe +- Numpy Array +- Dask Array +- Polars Dataframe + +.csv: +- Pandas Dataframe +- Dask Dataframe +- Polars Dataframe + +.npy_stack: +- Dask Array + +.pkl: +- Any Object diff --git a/epochlib/caching/__init__.py b/epochlib/caching/__init__.py index a22b347..5bb50a5 100644 --- a/epochlib/caching/__init__.py +++ b/epochlib/caching/__init__.py @@ -1,5 +1,34 @@ """Caching module for epochlib.""" from .cacher import CacheArgs, Cacher +from .any_to_pkl_cacher import AnyToPklCacher +from .cacher_interface import CacherInterface +from .dask_array_to_npy_cacher import DaskArrayToNpyCacher +from .dask_array_to_npy_stack_cacher import DaskArrayToNpyStackCacher +from .dask_array_to_parquet_cacher import DaskArrayToParquetCacher +from .dask_dataframe_to_csv_cacher import DaskDataFrameToCSVCacher +from .dask_dataframe_to_parquet_cacher import DaskDataFrameToParquetCacher +from .numpy_array_to_npy_cacher import NumpyArrayToNpyCacher +from .numpy_array_to_parquet_cacher import NumpyArrayToParquetCacher +from .pandas_dataframe_to_csv_cacher import PandasDataFrameToCSVCacher +from .pandas_dataframe_to_parquet_cacher import PandasDataFrameToParquetCacher +from .polars_dataframe_to_csv_cacher import PolarsDataFrameToCSVCacher +from .polars_dataframe_to_parquet_cacher import PolarsDataFrameToParquetCacher -__all__ = ["Cacher", "CacheArgs"] +__all__ = [ + "Cacher", + "CacheArgs", + "CacherInterface", + "NumpyArrayToNpyCacher", + "DaskArrayToNpyCacher", + "AnyToPklCacher", + "DaskArrayToNpyStackCacher", + "DaskArrayToParquetCacher", + "DaskDataFrameToCSVCacher", + "DaskDataFrameToParquetCacher", + "NumpyArrayToParquetCacher", + "PandasDataFrameToCSVCacher", + "PandasDataFrameToParquetCacher", + "PolarsDataFrameToCSVCacher", + "PolarsDataFrameToParquetCacher", +] diff --git a/epochlib/caching/any_to_pkl_cacher.py 
b/epochlib/caching/any_to_pkl_cacher.py new file mode 100644 index 0000000..a65b733 --- /dev/null +++ b/epochlib/caching/any_to_pkl_cacher.py @@ -0,0 +1,50 @@ +"""This module contains the AnyToPklCacher class.""" + +import os +import pickle +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from .cacher_interface import CacherInterface + + +@dataclass +class AnyToPklCacher(CacherInterface): + """The any to .pkl cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, pickle.load() extra args. + :param store_args: The arguments to store the cache, pickle.dump() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.pkl") + + def load_cache(self, name: str) -> Any: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + with open(Path(self.storage_path) / f"{name}.pkl", "rb") as file: + return pickle.load(file, **self.read_args) # noqa: S301 + + def store_cache(self, name: str, data: Any) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + with open(storage_path / f"{name}.pkl", "wb") as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL, **self.store_args) diff --git a/epochlib/caching/cacher_interface.py b/epochlib/caching/cacher_interface.py new file mode 100644 index 0000000..d910596 --- /dev/null +++ b/epochlib/caching/cacher_interface.py @@ -0,0 +1,40 @@ +"""The cache interface defines the methods that a cacher must implement.""" + +from typing import Any + + +class CacherInterface: + """The cache interface defines the methods that a cacher must implement. + + Methods + ------- + .. code-block:: python + def cache_exists(name: str) -> bool + + def load_cache(name: str) -> Any + + def store_cache(name: str, data: Any) -> None + """ + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + raise NotImplementedError(f"cache_exists from CacheInterface not implemented in {self.__class__.__name__}.") + + def load_cache(self, name: str) -> Any: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + raise NotImplementedError(f"load_cache from CacheInterface not implemented in {self.__class__.__name__}.") + + def store_cache(self, name: str, data: Any) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + :param data: The data to store in the cache. 
+ :param store_args: Any additional arguments to store the + """ + raise NotImplementedError(f"store_cache from CacheInterface not implemented in {self.__class__.__name__}.") diff --git a/epochlib/caching/dask_array_to_npy_cacher.py b/epochlib/caching/dask_array_to_npy_cacher.py new file mode 100644 index 0000000..1dfa4f7 --- /dev/null +++ b/epochlib/caching/dask_array_to_npy_cacher.py @@ -0,0 +1,65 @@ +"""This module contains the DaskArrayToNpyCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np + +from .cacher_interface import CacherInterface + +try: + import dask.array as da +except ImportError: + """User doesn't require these packages""" + + +@dataclass +class DaskArrayToNpyCacher(CacherInterface): + """The dask array to .npy cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, da.from_array() extra args. + :param store_args: The arguments to store the cache, np.save() extra args. + + Methods + ------- + .. code-block:: python + def cache_exists(name: str) -> bool + + def load_cache(name: str) -> Any + + def store_cache(name: str, data: Any) -> None + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.npy") + + def load_cache(self, name: str) -> da.Array: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + :param cache_args: The cache arguments. 
+ """ + return da.from_array(np.load(Path(self.storage_path) / f"{name}.npy"), **self.read_args) + + def store_cache(self, name: str, data: da.Array) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + :param cache_args: The cache arguments. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + np.save(storage_path / f"{name}.npy", data.compute(), **self.store_args) diff --git a/epochlib/caching/dask_array_to_npy_stack_cacher.py b/epochlib/caching/dask_array_to_npy_stack_cacher.py new file mode 100644 index 0000000..2fbbffa --- /dev/null +++ b/epochlib/caching/dask_array_to_npy_stack_cacher.py @@ -0,0 +1,49 @@ +"""This module contains the DaskArrayToNpyStackCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import dask.array as da + +from .cacher_interface import CacherInterface + + +@dataclass +class DaskArrayToNpyStackCacher(CacherInterface): + """The dask array to .npy_stack cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, da.from_npy_stack() extra args. + :param store_args: The arguments to store the cache, da.to_npy_stack() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / name) + + def load_cache(self, name: str) -> da.Array: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return da.from_npy_stack(Path(self.storage_path) / name, **self.read_args) + + def store_cache(self, name: str, data: da.Array) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + da.to_npy_stack(storage_path / name, data, **self.store_args) diff --git a/epochlib/caching/dask_array_to_parquet_cacher.py b/epochlib/caching/dask_array_to_parquet_cacher.py new file mode 100644 index 0000000..0ded2da --- /dev/null +++ b/epochlib/caching/dask_array_to_parquet_cacher.py @@ -0,0 +1,54 @@ +"""This module contains the DaskArrayToParquetCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import dask.array as da +import dask.dataframe as dd + +from .cacher_interface import CacherInterface + + +@dataclass +class DaskArrayToParquetCacher(CacherInterface): + """The dask array to .parquet cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, dd.read_parquet() extra args. + :param store_args: The arguments to store the cache, dd.to_parquet() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.parquet") + + def load_cache(self, name: str) -> da.Array: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return dd.read_parquet(Path(self.storage_path) / f"{name}.parquet", **self.read_args).to_dask_array() + + def store_cache(self, name: str, data: da.Array) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + new_dd = dd.from_dask_array(data) + new_dd = new_dd.rename( + columns={col: str(col) for col in new_dd.columns}, + ) + new_dd.to_parquet(storage_path / f"{name}.parquet", **self.store_args) diff --git a/epochlib/caching/dask_dataframe_to_csv_cacher.py b/epochlib/caching/dask_dataframe_to_csv_cacher.py new file mode 100644 index 0000000..760b175 --- /dev/null +++ b/epochlib/caching/dask_dataframe_to_csv_cacher.py @@ -0,0 +1,50 @@ +"""This module contains the DaskDataFrameToCSVCacher class.""" + +import glob +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import dask.dataframe as dd + +from .cacher_interface import CacherInterface + + +@dataclass +class DaskDataFrameToCSVCacher(CacherInterface): + """The dask dataframe to .csv cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, dd.read_csv() extra args. + :param store_args: The arguments to store the cache, dd.to_csv() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.csv") or glob.glob(str(storage_path / name / "*.part")) != [] + + def load_cache(self, name: str) -> dd.DataFrame: + """Load a cache. 
+ + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + return dd.read_csv(Path(self.storage_path) / name / "*.part", **self.read_args) + + def store_cache(self, name: str, data: dd.DataFrame) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + data.to_csv(storage_path / name, index=False, **self.store_args) diff --git a/epochlib/caching/dask_dataframe_to_parquet_cacher.py b/epochlib/caching/dask_dataframe_to_parquet_cacher.py new file mode 100644 index 0000000..45aa842 --- /dev/null +++ b/epochlib/caching/dask_dataframe_to_parquet_cacher.py @@ -0,0 +1,49 @@ +"""This module contains the DaskDataFrameToParquetCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import dask.dataframe as dd + +from .cacher_interface import CacherInterface + + +@dataclass +class DaskDataFrameToParquetCacher(CacherInterface): + """The dask dataframe to .parquet cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, dd.read_parquet() extra args. + :param store_args: The arguments to store the cache, dd.to_parquet() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.parquet") + + def load_cache(self, name: str) -> dd.DataFrame: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return dd.read_parquet(Path(self.storage_path) / f"{name}.parquet", **self.read_args) + + def store_cache(self, name: str, data: dd.DataFrame) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + data.to_parquet(storage_path / f"{name}.parquet", **self.store_args) diff --git a/epochlib/caching/numpy_array_to_npy_cacher.py b/epochlib/caching/numpy_array_to_npy_cacher.py new file mode 100644 index 0000000..b6d27a7 --- /dev/null +++ b/epochlib/caching/numpy_array_to_npy_cacher.py @@ -0,0 +1,61 @@ +"""This module contains the NumpyToNpyCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np +from numpy.typing import ArrayLike + +from .cacher_interface import CacherInterface + + +@dataclass +class NumpyArrayToNpyCacher(CacherInterface): + """The numpy array to .npy cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, np.load() extra args. + :param store_args: The arguments to store the cache, np.save() extra args. + + Methods + ------- + .. code-block:: python + def cache_exists(name: str) -> bool + + def load_cache(name: str) -> Any + + def store_cache(name: str, data: Any) -> None + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.npy") + + def load_cache(self, name: str) -> ArrayLike: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ :param cache_args: The cache arguments. + """ + return np.load(Path(self.storage_path) / f"{name}.npy", **self.read_args) + + def store_cache(self, name: str, data: ArrayLike) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + :param cache_args: The cache arguments. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + np.save(storage_path / f"{name}.npy", data, **self.store_args) diff --git a/epochlib/caching/numpy_array_to_parquet_cacher.py b/epochlib/caching/numpy_array_to_parquet_cacher.py new file mode 100644 index 0000000..a5bbfc5 --- /dev/null +++ b/epochlib/caching/numpy_array_to_parquet_cacher.py @@ -0,0 +1,50 @@ +"""This module contains the NumpyArrayToParquetCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import pandas as pd +from numpy.typing import ArrayLike + +from .cacher_interface import CacherInterface + + +@dataclass +class NumpyArrayToParquetCacher(CacherInterface): + """The numpy array to .parquet cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, pd.read_parquet() extra args. + :param store_args: The arguments to store the cache, pd.to_parquet() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.parquet") + + def load_cache(self, name: str) -> ArrayLike: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return pd.read_parquet(Path(self.storage_path) / f"{name}.parquet", **self.read_args).to_numpy() + + def store_cache(self, name: str, data: ArrayLike) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + pd.DataFrame(data).to_parquet(storage_path / f"{name}.parquet", **self.store_args) diff --git a/epochlib/caching/pandas_dataframe_to_csv_cacher.py b/epochlib/caching/pandas_dataframe_to_csv_cacher.py new file mode 100644 index 0000000..0349de4 --- /dev/null +++ b/epochlib/caching/pandas_dataframe_to_csv_cacher.py @@ -0,0 +1,49 @@ +"""This module contains the PandasDataFrameToCSVCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import pandas as pd + +from .cacher_interface import CacherInterface + + +@dataclass +class PandasDataFrameToCSVCacher(CacherInterface): + """The pandas dataframe to .csv cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, pd.read_csv() extra args. + :param store_args: The arguments to store the cache, pd.to_csv() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.csv") + + def load_cache(self, name: str) -> pd.DataFrame: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return pd.read_csv(Path(self.storage_path) / f"{name}.csv", **self.read_args) + + def store_cache(self, name: str, data: pd.DataFrame) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + data.to_csv(storage_path / f"{name}.csv", index=False, **self.store_args) diff --git a/epochlib/caching/pandas_dataframe_to_parquet_cacher.py b/epochlib/caching/pandas_dataframe_to_parquet_cacher.py new file mode 100644 index 0000000..735934b --- /dev/null +++ b/epochlib/caching/pandas_dataframe_to_parquet_cacher.py @@ -0,0 +1,49 @@ +"""This module contains the PandasDataFrameToParquetCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import pandas as pd + +from .cacher_interface import CacherInterface + + +@dataclass +class PandasDataFrameToParquetCacher(CacherInterface): + """The pandas dataframe to .parquet cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, pd.read_parquet() extra args. + :param store_args: The arguments to store the cache, pd.to_parquet() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.parquet") + + def load_cache(self, name: str) -> pd.DataFrame: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return pd.read_parquet(Path(self.storage_path) / f"{name}.parquet", **self.read_args) + + def store_cache(self, name: str, data: pd.DataFrame) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + data.to_parquet(storage_path / f"{name}.parquet", **self.store_args) diff --git a/epochlib/caching/polars_dataframe_to_csv_cacher.py b/epochlib/caching/polars_dataframe_to_csv_cacher.py new file mode 100644 index 0000000..0c81dcd --- /dev/null +++ b/epochlib/caching/polars_dataframe_to_csv_cacher.py @@ -0,0 +1,49 @@ +"""This module contains the PolarsDataFrameToCSVCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import polars as pl + +from .cacher_interface import CacherInterface + + +@dataclass +class PolarsDataFrameToCSVCacher(CacherInterface): + """The polars dataframe to .csv cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, pl.read_csv() extra args. + :param store_args: The arguments to store the cache, pl.write_csv() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.csv") + + def load_cache(self, name: str) -> pl.DataFrame: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return pl.read_csv(Path(self.storage_path) / f"{name}.csv", **self.read_args) + + def store_cache(self, name: str, data: pl.DataFrame) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + data.write_csv(storage_path / f"{name}.csv", **self.store_args) diff --git a/epochlib/caching/polars_dataframe_to_parquet_cacher.py b/epochlib/caching/polars_dataframe_to_parquet_cacher.py new file mode 100644 index 0000000..90d5426 --- /dev/null +++ b/epochlib/caching/polars_dataframe_to_parquet_cacher.py @@ -0,0 +1,49 @@ +"""This module contains the PolarsDataFrameToParquetCacher class.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import polars as pl + +from .cacher_interface import CacherInterface + + +@dataclass +class PolarsDataFrameToParquetCacher(CacherInterface): + """The polars dataframe to .parquet cacher. + + :param storage_path: The path to store the cache files. + :param read_args: The arguments to read the cache, pl.read_parquet() extra args. + :param store_args: The arguments to store the cache, pl.write_parquet() extra args. + """ + + storage_path: str + read_args: dict[str, Any] = field(default_factory=dict) + store_args: dict[str, Any] = field(default_factory=dict) + + def cache_exists(self, name: str) -> bool: + """Check if a cache exists. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + + return os.path.exists(storage_path / f"{name}.parquet") + + def load_cache(self, name: str) -> pl.DataFrame: + """Load a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. 
+ """ + return pl.read_parquet(Path(self.storage_path) / f"{name}.parquet", **self.read_args) + + def store_cache(self, name: str, data: pl.DataFrame) -> None: + """Store a cache. + + :param name: The name of the cache, cannot contain characters not in [a-zA-Z0-9_]. + """ + storage_path = Path(self.storage_path) + storage_path.mkdir(parents=True, exist_ok=True) + data.write_parquet(storage_path / f"{name}.parquet", **self.store_args) diff --git a/tests/caching/test_any_to_pkl_cacher.py b/tests/caching/test_any_to_pkl_cacher.py new file mode 100644 index 0000000..f5c978a --- /dev/null +++ b/tests/caching/test_any_to_pkl_cacher.py @@ -0,0 +1,28 @@ +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import AnyToPklCacher + + +class TestAnyToPklCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = AnyToPklCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = AnyToPklCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = AnyToPklCacher(self.cache_path) + data = {'a': 1, 'b': 2} + c.store_cache("name", data) + assert c.cache_exists("name") + assert c.load_cache("name") == data diff --git a/tests/caching/test_cache_interface.py b/tests/caching/test_cache_interface.py new file mode 100644 index 0000000..0ba217b --- /dev/null +++ b/tests/caching/test_cache_interface.py @@ -0,0 +1,20 @@ +import pytest +from epochlib.caching import CacherInterface + + +class Test_Cache_Interface: + + def test_cache_exists_raises_error(self): + c = CacherInterface() + with pytest.raises(NotImplementedError): + c.cache_exists("name") + + def test_load_cache_raises_error(self): + c = CacherInterface() + with pytest.raises(NotImplementedError): + c.load_cache("name") + + def test_store_cache_raises_error(self): + c = CacherInterface() + with 
pytest.raises(NotImplementedError): + c.store_cache("name", "data") diff --git a/tests/caching/test_dask_array_to_npy_cacher.py b/tests/caching/test_dask_array_to_npy_cacher.py new file mode 100644 index 0000000..524e814 --- /dev/null +++ b/tests/caching/test_dask_array_to_npy_cacher.py @@ -0,0 +1,29 @@ +import dask.array as da +import pytest +from tests.constants import TEMP_DIR +from epochlib.caching import DaskArrayToNpyCacher + + +class Test_DaskArrayToNpyCacher: + + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = DaskArrayToNpyCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = DaskArrayToNpyCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = DaskArrayToNpyCacher(self.cache_path) + x = da.ones((1000, 1000), chunks=(100, 100)) + c.store_cache("name", x) + assert c.cache_exists("name") + assert c.load_cache("name").shape == (1000, 1000) diff --git a/tests/caching/test_dask_array_to_npy_stack_cacher.py b/tests/caching/test_dask_array_to_npy_stack_cacher.py new file mode 100644 index 0000000..4dcceaa --- /dev/null +++ b/tests/caching/test_dask_array_to_npy_stack_cacher.py @@ -0,0 +1,31 @@ +import dask.array as da +import numpy as np +import pytest +from dask.array.utils import assert_eq +from tests.constants import TEMP_DIR + +from epochlib.caching import DaskArrayToNpyStackCacher + + +class TestDaskArrayToNpyStackCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = DaskArrayToNpyStackCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = DaskArrayToNpyStackCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = 
DaskArrayToNpyStackCacher(self.cache_path) + arr = da.from_array(np.array([1, 2, 3]), chunks=1) + c.store_cache("name", arr) + assert c.cache_exists("name") + assert_eq(c.load_cache("name"), arr) diff --git a/tests/caching/test_dask_array_to_parquet_cacher.py b/tests/caching/test_dask_array_to_parquet_cacher.py new file mode 100644 index 0000000..deb867d --- /dev/null +++ b/tests/caching/test_dask_array_to_parquet_cacher.py @@ -0,0 +1,31 @@ +import dask.array as da +import numpy as np +import pytest +from dask.array.utils import assert_eq +from tests.constants import TEMP_DIR + +from epochlib.caching import DaskArrayToParquetCacher + + +class TestDaskArrayToParquetCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = DaskArrayToParquetCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = DaskArrayToParquetCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = DaskArrayToParquetCacher(self.cache_path) + arr = da.from_array(np.array([1, 2, 3]), chunks=1) + c.store_cache("name", arr) + assert c.cache_exists("name") + assert_eq(c.load_cache("name"), arr) diff --git a/tests/caching/test_dask_dataframe_to_csv_cacher.py b/tests/caching/test_dask_dataframe_to_csv_cacher.py new file mode 100644 index 0000000..7ef6a3c --- /dev/null +++ b/tests/caching/test_dask_dataframe_to_csv_cacher.py @@ -0,0 +1,31 @@ +import dask.dataframe as dd +import pandas as pd +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import DaskDataFrameToCSVCacher + + +class TestDaskDataFrameToCSVCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = DaskDataFrameToCSVCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = 
DaskDataFrameToCSVCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = DaskDataFrameToCSVCacher(self.cache_path) + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + ddf = dd.from_pandas(df, npartitions=2) + c.store_cache("name", ddf) + assert c.cache_exists("name") + dd.assert_eq(c.load_cache("name"), ddf, check_index=False) diff --git a/tests/caching/test_dask_dataframe_to_parquet_cacher.py b/tests/caching/test_dask_dataframe_to_parquet_cacher.py new file mode 100644 index 0000000..63330fe --- /dev/null +++ b/tests/caching/test_dask_dataframe_to_parquet_cacher.py @@ -0,0 +1,31 @@ +import dask.dataframe as dd +import pandas as pd +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import DaskDataFrameToParquetCacher + + +class TestDaskDataFrameToParquetCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = DaskDataFrameToParquetCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = DaskDataFrameToParquetCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = DaskDataFrameToParquetCacher(self.cache_path) + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + ddf = dd.from_pandas(df, npartitions=2) + c.store_cache("name", ddf) + assert c.cache_exists("name") + dd.assert_eq(c.load_cache("name"), ddf) diff --git a/tests/caching/test_numpy_array_to_npy_cacher.py b/tests/caching/test_numpy_array_to_npy_cacher.py new file mode 100644 index 0000000..897a5a3 --- /dev/null +++ b/tests/caching/test_numpy_array_to_npy_cacher.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +from tests.constants import TEMP_DIR +from epochlib.caching import NumpyArrayToNpyCacher + + +class TestNumpyArrayToNpyCacher: + + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def 
run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = NumpyArrayToNpyCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = NumpyArrayToNpyCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = NumpyArrayToNpyCacher(self.cache_path) + c.store_cache("name", np.array([1, 2, 3])) + assert c.cache_exists("name") + assert np.array_equal(c.load_cache("name"), np.array([1, 2, 3])) diff --git a/tests/caching/test_numpy_array_to_parquet_cacher.py b/tests/caching/test_numpy_array_to_parquet_cacher.py new file mode 100644 index 0000000..6e5d7f9 --- /dev/null +++ b/tests/caching/test_numpy_array_to_parquet_cacher.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import NumpyArrayToParquetCacher + + +class TestNumpyArrayToParquetCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = NumpyArrayToParquetCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = NumpyArrayToParquetCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = NumpyArrayToParquetCacher(self.cache_path) + arr = np.array([1, 2, 3]) + c.store_cache("name", arr) + assert c.cache_exists("name") + np.testing.assert_array_equal(c.load_cache("name"), arr) diff --git a/tests/caching/test_pandas_dataframe_to_csv_cacher.py b/tests/caching/test_pandas_dataframe_to_csv_cacher.py new file mode 100644 index 0000000..2d65869 --- /dev/null +++ b/tests/caching/test_pandas_dataframe_to_csv_cacher.py @@ -0,0 +1,29 @@ +import pandas as pd +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import PandasDataFrameToCSVCacher + + +class TestPandasDataFrameToCSVCacher: + cache_path = TEMP_DIR + + 
@pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = PandasDataFrameToCSVCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = PandasDataFrameToCSVCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = PandasDataFrameToCSVCacher(self.cache_path) + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + c.store_cache("name", df) + assert c.cache_exists("name") + pd.testing.assert_frame_equal(c.load_cache("name"), df) diff --git a/tests/caching/test_pandas_dataframe_to_parquet_cacher.py b/tests/caching/test_pandas_dataframe_to_parquet_cacher.py new file mode 100644 index 0000000..67df7b2 --- /dev/null +++ b/tests/caching/test_pandas_dataframe_to_parquet_cacher.py @@ -0,0 +1,29 @@ +import pandas as pd +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import PandasDataFrameToParquetCacher + + +class TestPandasDataFrameToParquetCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = PandasDataFrameToParquetCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = PandasDataFrameToParquetCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = PandasDataFrameToParquetCacher(self.cache_path) + df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + c.store_cache("name", df) + assert c.cache_exists("name") + pd.testing.assert_frame_equal(c.load_cache("name"), df) diff --git a/tests/caching/test_polars_dataframe_to_csv_cacher.py b/tests/caching/test_polars_dataframe_to_csv_cacher.py new file mode 100644 index 0000000..8ed0aa8 --- /dev/null +++ b/tests/caching/test_polars_dataframe_to_csv_cacher.py @@ -0,0 +1,29 @@ +import polars as pl +import pytest +from tests.constants 
import TEMP_DIR + +from epochlib.caching import PolarsDataFrameToCSVCacher + + +class TestPolarsDataFrameToCSVCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = PolarsDataFrameToCSVCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = PolarsDataFrameToCSVCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = PolarsDataFrameToCSVCacher(self.cache_path) + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + c.store_cache("name", df) + assert c.cache_exists("name") + assert c.load_cache("name").equals(df) diff --git a/tests/caching/test_polars_dataframe_to_parquet_cacher.py b/tests/caching/test_polars_dataframe_to_parquet_cacher.py new file mode 100644 index 0000000..b74ae03 --- /dev/null +++ b/tests/caching/test_polars_dataframe_to_parquet_cacher.py @@ -0,0 +1,29 @@ +import polars as pl +import pytest +from tests.constants import TEMP_DIR + +from epochlib.caching import PolarsDataFrameToParquetCacher + + +class TestPolarsDataFrameToParquetCacher: + cache_path = TEMP_DIR + + @pytest.fixture(autouse=True) + def run_always(self, setup_temp_dir): + pass + + def test_cache_exists(self): + c = PolarsDataFrameToParquetCacher(self.cache_path) + assert not c.cache_exists("name") + + def test_load_cache(self): + c = PolarsDataFrameToParquetCacher(self.cache_path) + with pytest.raises(FileNotFoundError): + c.load_cache("name") + + def test_store_cache(self): + c = PolarsDataFrameToParquetCacher(self.cache_path) + df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) + c.store_cache("name", df) + assert c.cache_exists("name") + assert c.load_cache("name").equals(df)