diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 82e4822..77b7eb4 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,3 +1,6 @@
-*.yaml @TeamEpochGithub/core-maintainers
-*.yml @TeamEpochGithub/core-maintainers
-*.toml @TeamEpochGithub/core-maintainers
+*.py @TeamEpochGithub/packages
+*.md @TeamEpochGithub/packages
+*.yaml @TeamEpochGithub/packages
+*.yml @TeamEpochGithub/packages
+*.toml @TeamEpochGithub/packages
+*.txt @TeamEpochGithub/packages
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 586279b..06c5e39 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -4,6 +4,7 @@ updates:
directory: "/"
schedule:
interval: "monthly"
+ target-branch: "dependabot-updates"
- package-ecosystem: "pip"
directory: "/"
schedule:
diff --git a/.github/workflows/label-issue.yml b/.github/workflows/label-issue.yml
index d0c4860..f17bf1e 100644
--- a/.github/workflows/label-issue.yml
+++ b/.github/workflows/label-issue.yml
@@ -13,10 +13,10 @@ jobs:
issues: write
steps:
- name: checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- name: Check if issue has a milestone
id: check_milestone
- uses: actions/github-script@v6
+ uses: actions/github-script@v7
with:
script: |
const issue = context.payload.issue;
diff --git a/.github/workflows/main-branch-testing.yml b/.github/workflows/main-branch-testing.yml
index f6fc0de..221b4df 100644
--- a/.github/workflows/main-branch-testing.yml
+++ b/.github/workflows/main-branch-testing.yml
@@ -2,50 +2,46 @@ name: Main Branch CI/CD
on:
push:
- branches: [ "main" ]
+ branches: ["main"]
pull_request:
- branches: [ "main" ]
+ branches: ["main"]
jobs:
- build:
+ pytest:
runs-on: ubuntu-latest
- container:
- image: python:3.11-slim
- env:
- NODE_VERSION: 20
-
strategy:
- fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
-
steps:
- - name: Install Node.js
- uses: actions/setup-node@v3
+ - name: Check out repository
+ uses: actions/checkout@v4
with:
- node-version: ${{ env.NODE_VERSION }}
-
- - uses: actions/checkout@v3
-
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v3
+ fetch-depth: 0
+ - name: Install the latest version of Rye
+ uses: eifinger/setup-rye@v4.2.1
with:
python-version: ${{ matrix.python-version }}
-
- - name: Create virtual environment
- run: python -m venv venv
-
- - name: Activate virtual environment
- run: |
- . venv/bin/activate
-
- - name: Install dependencies
- run: |
- venv/bin/python -m pip install --upgrade pip
- venv/bin/python -m pip install pytest
- venv/bin/python -m pip install -r requirements-dev.lock
- venv/bin/python -m pip install pytest-cov coverage
-
+ - name: Setup the environment
+ run: rye sync --all-features
- name: Test with pytest
- run: |
- venv/bin/python -m pytest --cov=epochalyst --cov-branch --cov-fail-under=95 tests
+ run: rye run pytest --cov=epochlib --cov-branch --cov-fail-under=75 tests
+
+ build:
+ runs-on: ubuntu-latest
+ needs: pytest
+ strategy:
+ matrix:
+ python-version: ["3.10", "3.11", "3.12"]
+ steps:
+ - name: Check out repository
+ uses: actions/checkout@v4
+ - name: Install the latest version of Rye
+ uses: eifinger/setup-rye@v4.2.1
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Build the package
+ run: rye build
+ - uses: actions/upload-artifact@v4.3.6
+ with:
+ path: ./dist
+ name: dist-python-${{ matrix.python-version }}
diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index 3626ef0..0c21625 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -8,12 +8,12 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out repository
- uses: actions/checkout@v4.1.6
+ uses: actions/checkout@v4
- name: Install the latest version of Rye
- uses: eifinger/setup-rye@v3.0.2
+ uses: eifinger/setup-rye@v4.2.1
- name: Build the package
run: rye build
- - uses: actions/upload-artifact@v4.3.3
+ - uses: actions/upload-artifact@v4.3.6
with:
path: ./dist
@@ -22,12 +22,12 @@ jobs:
runs-on: ubuntu-latest
environment:
name: pypi
- url: https://pypi.org/p/epochalyst
+ url: https://pypi.org/p/epochlib
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
steps:
- - uses: actions/download-artifact@v4.1.7
+ - uses: actions/download-artifact@v4.1.8
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.github/workflows/retarget-dependabot-pr.yml b/.github/workflows/retarget-dependabot-pr.yml
new file mode 100644
index 0000000..2b89d24
--- /dev/null
+++ b/.github/workflows/retarget-dependabot-pr.yml
@@ -0,0 +1,35 @@
+name: Retarget Dependabot PR
+
+on:
+ pull_request:
+ branches:
+ - dependabot-updates
+ types:
+ - opened
+ - synchronize
+
+jobs:
+ retarget:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Get all branches
+ run: |
+ git fetch --all
+ git branch -r > branches.txt
+ cat branches.txt
+
+ - name: Find the lowest version branch
+ id: find_branch
+ run: |
+ lowest_branch=$(grep -o 'origin/v[0-9]*\.[0-9]*' branches.txt | sort -V | head -n 1 | sed 's/origin\///')
+ echo "lowest_branch=$lowest_branch" >> $GITHUB_ENV
+
+ - name: Retarget PR to lowest version branch
+ if: success()
+ run: |
+ gh pr edit ${{ github.event.pull_request.number }} --base ${{ env.lowest_branch }}
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/static-analysis.yml b/.github/workflows/static-analysis.yml
index e8f2247..0cada10 100644
--- a/.github/workflows/static-analysis.yml
+++ b/.github/workflows/static-analysis.yml
@@ -11,11 +11,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out repository
- uses: actions/checkout@v4.1.6
+ uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python 3.10.14
- uses: actions/setup-python@v5.1.0
+ uses: actions/setup-python@v5
with:
python-version: 3.10.14
- name: Run pre-commit
diff --git a/.github/workflows/version-branch-testing.yml b/.github/workflows/version-branch-testing.yml
index 264ac2a..056e1f5 100644
--- a/.github/workflows/version-branch-testing.yml
+++ b/.github/workflows/version-branch-testing.yml
@@ -9,25 +9,25 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out repository
- uses: actions/checkout@v4.1.6
+ uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install the latest version of Rye
- uses: eifinger/setup-rye@v3.0.2
+ uses: eifinger/setup-rye@v4.2.1
- name: Setup the environment
run: rye sync --all-features
- name: Test with pytest
- run: rye run pytest --cov=epochalyst --cov-branch --cov-fail-under=95 tests
+ run: rye run pytest --cov=epochlib --cov-branch --cov-fail-under=75 tests
build:
runs-on: ubuntu-latest
steps:
- name: Check out repository
- uses: actions/checkout@v4.1.6
+ uses: actions/checkout@v4
- name: Install the latest version of Rye
- uses: eifinger/setup-rye@v3.0.2
+ uses: eifinger/setup-rye@v4.2.1
- name: Build the package
run: rye build
- - uses: actions/upload-artifact@v4.3.3
+ - uses: actions/upload-artifact@v4.3.6
with:
path: ./dist
diff --git a/.gitignore b/.gitignore
index 6407b02..9077972 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,9 +164,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
-# Ignore data files
-data/
-
# Ignore logs
/logging/
/wandb/
diff --git a/CITATION.cff b/CITATION.cff
index f012da8..617d1f0 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -1,6 +1,11 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
+- family-names: "Selm"
+ name-particle: "van"
+ given-names: "Jasper"
+ affiliation: "TU Delft Dream Team Epoch"
+ email: "jmvanselm@gmail.com"
- family-names: "Lim"
given-names: "Jeffrey"
affiliation: "TU Delft Dream Team Epoch"
@@ -18,11 +23,6 @@ authors:
given-names: "Cahit Tolga"
affiliation: "TU Delft Dream Team Epoch"
email: "cahittolgakopar@gmail.com"
-- family-names: "Selm"
- name-particle: "van"
- given-names: "Jasper"
- affiliation: "TU Delft Dream Team Epoch"
- email: "jmvanselm@gmail.com"
- family-names: "Ebersberger"
given-names: "Ariel"
affiliation: "TU Delft Dream Team Epoch"
@@ -39,7 +39,7 @@ authors:
given-names: "Daniel"
affiliation: "TU Delft Dream Team Epoch"
email: "danieldediosallegue@gmail.com"
-title: "Epochalyst"
+title: "EpochLib"
version: 1.0.0
date-released: 2024-03-19
-url: "https://github.com/TeamEpochGithub/epochalyst"
+url: "https://github.com/TeamEpochGithub/epochlib"
diff --git a/README.md b/README.md
index e1c33c9..57f8f5e 100644
--- a/README.md
+++ b/README.md
@@ -1,49 +1,47 @@
-
+
[](https://teamepoch.ai/)
-[](https://pypi.org/project/epochalyst/)
-[](https://pypi.org/project/epochalyst/)
+[](https://pypi.org/project/epochlib/)
+[](https://pypi.org/project/epochlib/)
[](https://www.python.org/downloads/)
[](https://rye-up.com)
[](https://github.com/astral-sh/ruff)
[](https://mypy-lang.org/)
-[](https://results.pre-commit.ci/latest/github/TeamEpochGithub/epochalyst/main)
+[](https://results.pre-commit.ci/latest/github/TeamEpochGithub/epochlib/main)
-Epochalyst is the base for [Team Epoch](https://teamepoch.ai/) competitions.
+EpochLib is the base for [Team Epoch](https://teamepoch.ai/) competitions.
-This package contains many modules and classes necessary to construct the src code for machine learning competitions.
-
-Epochalyst: A fusion of "Epoch" and "Catalyst," this name positions your pipeline as a catalyst in the field of machine learning, sparking significant advancements and transformations.
+This library package contains many modules and classes necessary to construct the src code for machine learning competitions.
## Installation
-Install `epochalyst` using [Rye](https://rye.astral.sh/):
+Install `epochlib` using [Rye](https://rye.astral.sh/):
```shell
-rye add epochalyst
+rye add epochlib
```
Or via pip:
```shell
-pip install epochalyst
+pip install epochlib
```
### Optional Dependencies
Depending on what data libraries you use, you can install the following optional dependencies:
```shell
-rye add epochalyst[numpy,pandas,dask,polars] # Pick one or more of these
+rye add epochlib[numpy,pandas,dask,polars] # Pick one or more of these
```
Depending on what type of competition you are participating in, you can install the following optional dependencies:
```shell
-rye add epochalyst[image,audio] # Pick one or more of these
+rye add epochlib[image,audio] # Pick one or more of these
```
Aside from that, you can install the following optional dependencies:
```shell
-rye add epochalyst[onnx,openvino] # Optimizing model inference
+rye add epochlib[onnx,openvino] # Optimizing model inference
```
## Pytest coverage report
@@ -51,7 +49,7 @@ rye add epochalyst[onnx,openvino] # Optimizing model inference
To generate pytest coverage report run
```shell
-rye run pytest --cov=epochalyst --cov-branch --cov-report=html:coverage_re
+rye run pytest --cov=epochlib --cov-branch --cov-report=html:coverage_re
```
## pre-commit
@@ -72,21 +70,11 @@ rye run pre-commit run --all-files
## Documentation
-Documentation is generated using [Sphinx](https://www.sphinx-doc.org/en/master/) and can be found [here](https://teamepochgithub.github.io/epochalyst).
+Documentation is generated using [Sphinx](https://www.sphinx-doc.org/en/master/) and can be found [here](https://teamepochgithub.github.io/epochlib).
To make the documentation yourself, run `make html` with `docs` as the working directory.
The documentation can then be found in `docs/_build/html/index.html`.
-## Contributors
-
-Epochalyst was created by [Team Epoch IV](https://teamepoch.ai/team#iv), based in the [Dream Hall](https://www.tudelft.nl/ddream) of the [Delft University of Technology](https://www.tudelft.nl/).
+## Maintainers
-[](https://github.com/EWitting)
-[](https://github.com/Jeffrey-Lim)
-[](https://github.com/hjdeheer)
-[](https://github.com/schobbejak)
-[](https://github.com/tolgakopar)
-[](https://github.com/justanotherariel)
-[](https://github.com/Gregoire-Andre-Dumont)
-[](https://github.com/emherk)
-[](https://github.com/daniallegue)
+EpochLib is maintained by [Team Epoch](https://teamepoch.ai), based in the [Dream Hall](https://www.tudelft.nl/ddream) of the [Delft University of Technology](https://www.tudelft.nl/).
diff --git a/assets/Epochalyst_Icon.svg b/assets/EpochLib_Icon.svg
similarity index 99%
rename from assets/Epochalyst_Icon.svg
rename to assets/EpochLib_Icon.svg
index 0971d2f..aaf4e69 100644
--- a/assets/Epochalyst_Icon.svg
+++ b/assets/EpochLib_Icon.svg
@@ -8,7 +8,7 @@
version="1.1"
id="svg1"
xml:space="preserve"
- sodipodi:docname="Epochalyst_Icon.svg"
+ sodipodi:docname="EpochLib.svg"
inkscape:version="1.3.2 (091e20e, 2023-11-25, custom)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
@@ -54,4 +54,4 @@
clip-path="url(#clipPath4)"
transform="matrix(1.3085863,0,0,1.3085863,-523.43452,-130.85863)"
style="display:inline"
- inkscape:label="Epochalyst Icon" />
+ inkscape:label="EpochLib Icon" />
diff --git a/assets/Epochalyst_Logo_Auto.svg b/assets/EpochLib_Logo_Auto.svg
similarity index 99%
rename from assets/Epochalyst_Logo_Auto.svg
rename to assets/EpochLib_Logo_Auto.svg
index d482271..09b0f71 100644
--- a/assets/Epochalyst_Logo_Auto.svg
+++ b/assets/EpochLib_Logo_Auto.svg
@@ -40,4 +40,4 @@
style="-inkscape-font-specification:'Plus Jakarta Sans Semi-Bold';stroke-width:0.157377"
d="M -6.2972931,202.17036 V 170.90473 H 14.308651 v 4.40656 H -1.3451516 v 8.981 H 13.469305 v 4.40657 H -1.3451516 v 9.06494 H 14.308651 v 4.40656 z m 24.7606991,8.39346 v -31.09777 h 4.532469 v 4.70034 L 22.4503,183.07524 q 1.259019,-1.93049 3.357384,-3.02164 2.098365,-1.09115 4.784272,-1.09115 3.231482,0 5.791488,1.55279 2.601972,1.55279 4.070828,4.23869 1.510822,2.68591 1.510822,6.0433 0,3.35738 -1.510822,6.04329 -1.468856,2.6859 -4.028861,4.28066 -2.560006,1.55279 -5.833455,1.55279 -2.64394,0 -4.826239,-1.09115 -2.140333,-1.09115 -3.315417,-3.14755 l 0.755411,-0.88131 v 13.00986 z m 11.66691,-12.29641 q 2.01443,0 3.56722,-0.96525 1.55279,-0.96525 2.434104,-2.64394 0.92328,-1.72066 0.92328,-3.86099 0,-2.14034 -0.92328,-3.81903 -0.881314,-1.67869 -2.434104,-2.64394 -1.55279,-0.96525 -3.56722,-0.96525 -1.972463,0 -3.567221,0.96525 -1.55279,0.96525 -2.476071,2.64394 -0.881313,1.67869 -0.881313,3.81903 0,2.14033 0.881313,3.86099 0.923281,1.67869 2.476071,2.64394 1.594758,0.96525 3.567221,0.96525 z m 27.404651,4.40656 q -3.273449,0 -6.001324,-1.55279 -2.685907,-1.55279 -4.280665,-4.2387 -1.594757,-2.6859 -1.594757,-6.08525 0,-3.44132 1.594757,-6.08526 1.594758,-2.68591 4.280665,-4.19673 2.685907,-1.55279 6.001324,-1.55279 3.357384,0 6.001324,1.55279 2.685907,1.51082 4.238697,4.19673 1.594758,2.64394 1.594758,6.08526 0,3.44131 -1.594758,6.12722 -1.594757,2.68591 -4.280664,4.2387 -2.685908,1.51082 -5.959357,1.51082 z m 0,-4.40656 q 2.01443,0 3.567221,-0.96525 1.55279,-0.96525 2.434103,-2.64394 0.923281,-1.72066 0.923281,-3.86099 0,-2.14034 -0.923281,-3.81903 -0.881313,-1.67869 -2.434103,-2.64394 -1.552791,-0.96525 -3.567221,-0.96525 -1.972463,0 -3.56722,0.96525 -1.552791,0.96525 -2.476071,2.64394 -0.881314,1.67869 -0.881314,3.81903 0,2.14033 0.881314,3.86099 0.92328,1.67869 2.476071,2.64394 1.594757,0.96525 3.56722,0.96525 z m 27.236775,4.40656 q -3.399351,0 -6.043291,-1.55279 -2.601973,-1.59476 -4.154763,-4.28066 -1.510823,-2.68591 -1.510823,-6.08526 0,-3.35739 1.510823,-6.04329 1.510823,-2.68591 4.154763,-4.19673 2.64394,-1.55279 6.043291,-1.55279 2.308201,0 4.322632,0.83934 2.01443,0.79738 3.483286,2.22427 1.510823,1.42689 2.224267,3.31542 l -4.154763,1.93049 q -0.713444,-1.76262 -2.308202,-2.81181 -1.55279,-1.09115 -3.56722,-1.09115 -1.930496,0 -3.483286,0.96525 -1.510823,0.92328 -2.392136,2.64394 -0.881314,1.67869 -0.881314,3.81903 0,2.14033 0.881314,3.86099 0.881313,1.67869 2.392136,2.64394 1.55279,0.96525 3.483286,0.96525 2.056398,0 3.56722,-1.04919 1.552791,-1.09115 2.308202,-2.89574 l 4.154763,1.97246 q -0.671477,1.8046 -2.1823,3.27345 -1.468855,1.42689 -3.483286,2.26624 -2.01443,0.83934 -4.364599,0.83934 z m 14.646624,-0.50361 v -31.76924 h 4.742304 v 13.51347 l -0.75541,-0.58754 q 0.83935,-2.14034 2.68591,-3.23149 1.84656,-1.13311 4.28066,-1.13311 2.51804,0 4.44854,1.09115 1.93049,1.09115 3.02164,3.02164 1.09115,1.9305 1.09115,4.40657 v 14.68855 h -4.70034 V 188.7828 q 0,-1.72066 -0.67147,-2.89575 -0.62951,-1.21705 -1.76263,-1.84656 -1.13312,-0.67148 -2.60197,-0.67148 -1.42689,0 -2.60198,0.67148 -1.13311,0.62951 -1.80459,1.84656 -0.62951,1.21705 -0.62951,2.89575 v 13.38756 z m 31.391564,0.50361 q -2.30821,0 -4.07083,-0.79738 -1.72066,-0.83934 -2.68591,-2.26623 -0.96525,-1.46886 -0.96525,-3.44132 0,-1.84656 0.79738,-3.31542 0.83935,-1.46885 2.56001,-2.47607 1.72066,-1.00721 4.32263,-1.42689 l 7.88985,-1.30098 v 3.73509 l -6.96657,1.21705 q -1.88853,0.33574 -2.76984,1.21705 -0.88132,0.83935 -0.88132,2.1823 0,1.30099 0.96525,2.14033 1.00722,0.79738 2.56001,0.79738 1.93049,0 3.35738,-0.83934 1.46886,-0.83935 2.26624,-2.22427 0.79737,-1.42689 0.79737,-3.14755 v -5.83345 q 0,-1.67869 -1.25901,-2.72788 -1.21706,-1.09115 -3.27345,-1.09115 -1.88853,0 -3.31542,1.00722 -1.38492,0.96525 -2.0564,2.51804 l -3.94492,-1.97247 q 0.6295,-1.67869 2.05639,-2.93771 1.42689,-1.30098 3.31542,-2.01443 1.9305,-0.71344 4.07083,-0.71344 2.68591,0 4.7423,1.00721 2.09837,1.00722 3.23149,2.81181 1.17508,1.76263 1.17508,4.1128 v 15.27609 h -4.53247 v -4.11279 l 0.96525,0.1259 q -0.79738,1.38492 -2.0564,2.39214 -1.21705,1.00721 -2.81181,1.55279 -1.55279,0.54557 -3.48328,0.54557 z m 17.45836,-0.50361 v -31.76924 h 4.7423 v 31.76924 z m 12.29646,9.27478 q -0.83935,0 -1.6787,-0.16787 -0.79737,-0.1259 -1.46885,-0.41967 v -4.07083 q 0.50361,0.1259 1.21705,0.20983 0.71345,0.12591 1.38492,0.12591 1.97247,0 2.89575,-0.88132 0.96524,-0.83934 1.72066,-2.64394 l 1.51082,-3.44132 -0.0839,3.9869 -9.82035,-24.67678 h 5.07805 l 7.26034,18.92726 h -1.72066 l 7.21838,-18.92726 h 5.12001 l -10.03019,25.18039 q -0.79738,2.01443 -2.01443,3.52525 -1.17508,1.55279 -2.81181,2.39213 -1.59475,0.88132 -3.77705,0.88132 z m 30.25842,-8.77117 q -3.48328,0 -6.12722,-1.72066 -2.60198,-1.72066 -3.65116,-4.6164 l 3.65116,-1.72066 q 0.92328,1.9305 2.51803,3.06361 1.63673,1.13312 3.60919,1.13312 1.67869,0 2.72788,-0.75541 1.04918,-0.75541 1.04918,-2.0564 0,-0.83934 -0.46164,-1.34295 -0.46164,-0.54558 -1.17509,-0.88132 -0.67147,-0.33573 -1.38492,-0.5036 l -3.56722,-1.00722 q -2.93771,-0.83934 -4.40656,-2.51804 -1.42689,-1.72066 -1.42689,-3.98689 0,-2.0564 1.04918,-3.56722 1.04918,-1.55279 2.89574,-2.39214 1.84656,-0.83934 4.15477,-0.83934 3.10558,0 5.53968,1.55279 2.4341,1.51082 3.44132,4.23869 l -3.65116,1.72066 q -0.67147,-1.63672 -2.14033,-2.60197 -1.42689,-0.96525 -3.23148,-0.96525 -1.55279,0 -2.47607,0.75541 -0.92328,0.71345 -0.92328,1.88853 0,0.79738 0.41967,1.34296 0.41967,0.5036 1.09115,0.83934 0.67148,0.29377 1.38492,0.50361 l 3.69312,1.09115 q 2.81181,0.79738 4.32264,2.51804 1.51082,1.67869 1.51082,4.02886 0,2.01443 -1.09115,3.56722 -1.04918,1.51082 -2.93771,2.39214 -1.88853,0.83934 -4.40657,0.83934 z m 22.91412,-0.2518 q -3.56722,0 -5.53969,-2.01443 -1.97246,-2.01443 -1.97246,-5.66559 v -10.99543 h -3.98689 v -4.28067 h 0.62951 q 1.59475,0 2.47607,-0.92328 0.88131,-0.92328 0.88131,-2.51803 v -1.76263 h 4.74231 v 5.20394 h 5.16197 v 4.28067 h -5.16197 v 10.7856 q 0,1.17508 0.3777,2.01443 0.37771,0.79737 1.21705,1.25902 0.83935,0.41967 2.1823,0.41967 0.33574,0 0.75541,-0.042 0.41968,-0.042 0.79738,-0.0839 v 4.07082 q -0.58754,0.0839 -1.30098,0.16787 -0.71345,0.0839 -1.25902,0.0839 z"
id="text14"
- aria-label="Epochalyst" />
+ aria-label="EpochLib" />
diff --git a/assets/Epochalyst_Logo_Dark.svg b/assets/EpochLib_Logo_Dark.svg
similarity index 99%
rename from assets/Epochalyst_Logo_Dark.svg
rename to assets/EpochLib_Logo_Dark.svg
index 4e00cb6..fb538a2 100644
--- a/assets/Epochalyst_Logo_Dark.svg
+++ b/assets/EpochLib_Logo_Dark.svg
@@ -32,4 +32,4 @@
style="font-weight:600;font-size:41.9673px;line-height:0;font-family:'Plus Jakarta Sans';-inkscape-font-specification:'Plus Jakarta Sans Semi-Bold';stroke-width:0.157377"
d="M -6.2972931,202.17036 V 170.90473 H 14.308651 v 4.40656 H -1.3451516 v 8.981 H 13.469305 v 4.40657 H -1.3451516 v 9.06494 H 14.308651 v 4.40656 z m 24.7606991,8.39346 v -31.09777 h 4.532469 v 4.70034 L 22.4503,183.07524 q 1.259019,-1.93049 3.357384,-3.02164 2.098365,-1.09115 4.784272,-1.09115 3.231482,0 5.791488,1.55279 2.601972,1.55279 4.070828,4.23869 1.510822,2.68591 1.510822,6.0433 0,3.35738 -1.510822,6.04329 -1.468856,2.6859 -4.028861,4.28066 -2.560006,1.55279 -5.833455,1.55279 -2.64394,0 -4.826239,-1.09115 -2.140333,-1.09115 -3.315417,-3.14755 l 0.755411,-0.88131 v 13.00986 z m 11.66691,-12.29641 q 2.01443,0 3.56722,-0.96525 1.55279,-0.96525 2.434104,-2.64394 0.92328,-1.72066 0.92328,-3.86099 0,-2.14034 -0.92328,-3.81903 -0.881314,-1.67869 -2.434104,-2.64394 -1.55279,-0.96525 -3.56722,-0.96525 -1.972463,0 -3.567221,0.96525 -1.55279,0.96525 -2.476071,2.64394 -0.881313,1.67869 -0.881313,3.81903 0,2.14033 0.881313,3.86099 0.923281,1.67869 2.476071,2.64394 1.594758,0.96525 3.567221,0.96525 z m 27.404651,4.40656 q -3.273449,0 -6.001324,-1.55279 -2.685907,-1.55279 -4.280665,-4.2387 -1.594757,-2.6859 -1.594757,-6.08525 0,-3.44132 1.594757,-6.08526 1.594758,-2.68591 4.280665,-4.19673 2.685907,-1.55279 6.001324,-1.55279 3.357384,0 6.001324,1.55279 2.685907,1.51082 4.238697,4.19673 1.594758,2.64394 1.594758,6.08526 0,3.44131 -1.594758,6.12722 -1.594757,2.68591 -4.280664,4.2387 -2.685908,1.51082 -5.959357,1.51082 z m 0,-4.40656 q 2.01443,0 3.567221,-0.96525 1.55279,-0.96525 2.434103,-2.64394 0.923281,-1.72066 0.923281,-3.86099 0,-2.14034 -0.923281,-3.81903 -0.881313,-1.67869 -2.434103,-2.64394 -1.552791,-0.96525 -3.567221,-0.96525 -1.972463,0 -3.56722,0.96525 -1.552791,0.96525 -2.476071,2.64394 -0.881314,1.67869 -0.881314,3.81903 0,2.14033 0.881314,3.86099 0.92328,1.67869 2.476071,2.64394 1.594757,0.96525 3.56722,0.96525 z m 27.236775,4.40656 q -3.399351,0 -6.043291,-1.55279 -2.601973,-1.59476 -4.154763,-4.28066 -1.510823,-2.68591 -1.510823,-6.08526 0,-3.35739 1.510823,-6.04329 1.510823,-2.68591 4.154763,-4.19673 2.64394,-1.55279 6.043291,-1.55279 2.308201,0 4.322632,0.83934 2.01443,0.79738 3.483286,2.22427 1.510823,1.42689 2.224267,3.31542 l -4.154763,1.93049 q -0.713444,-1.76262 -2.308202,-2.81181 -1.55279,-1.09115 -3.56722,-1.09115 -1.930496,0 -3.483286,0.96525 -1.510823,0.92328 -2.392136,2.64394 -0.881314,1.67869 -0.881314,3.81903 0,2.14033 0.881314,3.86099 0.881313,1.67869 2.392136,2.64394 1.55279,0.96525 3.483286,0.96525 2.056398,0 3.56722,-1.04919 1.552791,-1.09115 2.308202,-2.89574 l 4.154763,1.97246 q -0.671477,1.8046 -2.1823,3.27345 -1.468855,1.42689 -3.483286,2.26624 -2.01443,0.83934 -4.364599,0.83934 z m 14.646624,-0.50361 v -31.76924 h 4.742304 v 13.51347 l -0.75541,-0.58754 q 0.83935,-2.14034 2.68591,-3.23149 1.84656,-1.13311 4.28066,-1.13311 2.51804,0 4.44854,1.09115 1.93049,1.09115 3.02164,3.02164 1.09115,1.9305 1.09115,4.40657 v 14.68855 h -4.70034 V 188.7828 q 0,-1.72066 -0.67147,-2.89575 -0.62951,-1.21705 -1.76263,-1.84656 -1.13312,-0.67148 -2.60197,-0.67148 -1.42689,0 -2.60198,0.67148 -1.13311,0.62951 -1.80459,1.84656 -0.62951,1.21705 -0.62951,2.89575 v 13.38756 z m 31.391564,0.50361 q -2.30821,0 -4.07083,-0.79738 -1.72066,-0.83934 -2.68591,-2.26623 -0.96525,-1.46886 -0.96525,-3.44132 0,-1.84656 0.79738,-3.31542 0.83935,-1.46885 2.56001,-2.47607 1.72066,-1.00721 4.32263,-1.42689 l 7.88985,-1.30098 v 3.73509 l -6.96657,1.21705 q -1.88853,0.33574 -2.76984,1.21705 -0.88132,0.83935 -0.88132,2.1823 0,1.30099 0.96525,2.14033 1.00722,0.79738 2.56001,0.79738 1.93049,0 3.35738,-0.83934 1.46886,-0.83935 2.26624,-2.22427 0.79737,-1.42689 0.79737,-3.14755 v -5.83345 q 0,-1.67869 -1.25901,-2.72788 -1.21706,-1.09115 -3.27345,-1.09115 -1.88853,0 -3.31542,1.00722 -1.38492,0.96525 -2.0564,2.51804 l -3.94492,-1.97247 q 0.6295,-1.67869 2.05639,-2.93771 1.42689,-1.30098 3.31542,-2.01443 1.9305,-0.71344 4.07083,-0.71344 2.68591,0 4.7423,1.00721 2.09837,1.00722 3.23149,2.81181 1.17508,1.76263 1.17508,4.1128 v 15.27609 h -4.53247 v -4.11279 l 0.96525,0.1259 q -0.79738,1.38492 -2.0564,2.39214 -1.21705,1.00721 -2.81181,1.55279 -1.55279,0.54557 -3.48328,0.54557 z m 17.45836,-0.50361 v -31.76924 h 4.7423 v 31.76924 z m 12.29646,9.27478 q -0.83935,0 -1.6787,-0.16787 -0.79737,-0.1259 -1.46885,-0.41967 v -4.07083 q 0.50361,0.1259 1.21705,0.20983 0.71345,0.12591 1.38492,0.12591 1.97247,0 2.89575,-0.88132 0.96524,-0.83934 1.72066,-2.64394 l 1.51082,-3.44132 -0.0839,3.9869 -9.82035,-24.67678 h 5.07805 l 7.26034,18.92726 h -1.72066 l 7.21838,-18.92726 h 5.12001 l -10.03019,25.18039 q -0.79738,2.01443 -2.01443,3.52525 -1.17508,1.55279 -2.81181,2.39213 -1.59475,0.88132 -3.77705,0.88132 z m 30.25842,-8.77117 q -3.48328,0 -6.12722,-1.72066 -2.60198,-1.72066 -3.65116,-4.6164 l 3.65116,-1.72066 q 0.92328,1.9305 2.51803,3.06361 1.63673,1.13312 3.60919,1.13312 1.67869,0 2.72788,-0.75541 1.04918,-0.75541 1.04918,-2.0564 0,-0.83934 -0.46164,-1.34295 -0.46164,-0.54558 -1.17509,-0.88132 -0.67147,-0.33573 -1.38492,-0.5036 l -3.56722,-1.00722 q -2.93771,-0.83934 -4.40656,-2.51804 -1.42689,-1.72066 -1.42689,-3.98689 0,-2.0564 1.04918,-3.56722 1.04918,-1.55279 2.89574,-2.39214 1.84656,-0.83934 4.15477,-0.83934 3.10558,0 5.53968,1.55279 2.4341,1.51082 3.44132,4.23869 l -3.65116,1.72066 q -0.67147,-1.63672 -2.14033,-2.60197 -1.42689,-0.96525 -3.23148,-0.96525 -1.55279,0 -2.47607,0.75541 -0.92328,0.71345 -0.92328,1.88853 0,0.79738 0.41967,1.34296 0.41967,0.5036 1.09115,0.83934 0.67148,0.29377 1.38492,0.50361 l 3.69312,1.09115 q 2.81181,0.79738 4.32264,2.51804 1.51082,1.67869 1.51082,4.02886 0,2.01443 -1.09115,3.56722 -1.04918,1.51082 -2.93771,2.39214 -1.88853,0.83934 -4.40657,0.83934 z m 22.91412,-0.2518 q -3.56722,0 -5.53969,-2.01443 -1.97246,-2.01443 -1.97246,-5.66559 v -10.99543 h -3.98689 v -4.28067 h 0.62951 q 1.59475,0 2.47607,-0.92328 0.88131,-0.92328 0.88131,-2.51803 v -1.76263 h 4.74231 v 5.20394 h 5.16197 v 4.28067 h -5.16197 v 10.7856 q 0,1.17508 0.3777,2.01443 0.37771,0.79737 1.21705,1.25902 0.83935,0.41967 2.1823,0.41967 0.33574,0 0.75541,-0.042 0.41968,-0.042 0.79738,-0.0839 v 4.07082 q -0.58754,0.0839 -1.30098,0.16787 -0.71345,0.0839 -1.25902,0.0839 z"
id="text14"
- aria-label="Epochalyst" />
+ aria-label="EpochLib" />
diff --git a/assets/Epochalyst_Logo_Light.svg b/assets/EpochLib_Logo_Light.svg
similarity index 99%
rename from assets/Epochalyst_Logo_Light.svg
rename to assets/EpochLib_Logo_Light.svg
index 4defe20..e959efe 100644
--- a/assets/Epochalyst_Logo_Light.svg
+++ b/assets/EpochLib_Logo_Light.svg
@@ -33,4 +33,4 @@
style="font-weight:600;font-size:41.9673px;line-height:0;font-family:'Plus Jakarta Sans';-inkscape-font-specification:'Plus Jakarta Sans Semi-Bold';stroke-width:0.157377"
d="M -6.2972931,202.17036 V 170.90473 H 14.308651 v 4.40656 H -1.3451516 v 8.981 H 13.469305 v 4.40657 H -1.3451516 v 9.06494 H 14.308651 v 4.40656 z m 24.7606991,8.39346 v -31.09777 h 4.532469 v 4.70034 L 22.4503,183.07524 q 1.259019,-1.93049 3.357384,-3.02164 2.098365,-1.09115 4.784272,-1.09115 3.231482,0 5.791488,1.55279 2.601972,1.55279 4.070828,4.23869 1.510822,2.68591 1.510822,6.0433 0,3.35738 -1.510822,6.04329 -1.468856,2.6859 -4.028861,4.28066 -2.560006,1.55279 -5.833455,1.55279 -2.64394,0 -4.826239,-1.09115 -2.140333,-1.09115 -3.315417,-3.14755 l 0.755411,-0.88131 v 13.00986 z m 11.66691,-12.29641 q 2.01443,0 3.56722,-0.96525 1.55279,-0.96525 2.434104,-2.64394 0.92328,-1.72066 0.92328,-3.86099 0,-2.14034 -0.92328,-3.81903 -0.881314,-1.67869 -2.434104,-2.64394 -1.55279,-0.96525 -3.56722,-0.96525 -1.972463,0 -3.567221,0.96525 -1.55279,0.96525 -2.476071,2.64394 -0.881313,1.67869 -0.881313,3.81903 0,2.14033 0.881313,3.86099 0.923281,1.67869 2.476071,2.64394 1.594758,0.96525 3.567221,0.96525 z m 27.404651,4.40656 q -3.273449,0 -6.001324,-1.55279 -2.685907,-1.55279 -4.280665,-4.2387 -1.594757,-2.6859 -1.594757,-6.08525 0,-3.44132 1.594757,-6.08526 1.594758,-2.68591 4.280665,-4.19673 2.685907,-1.55279 6.001324,-1.55279 3.357384,0 6.001324,1.55279 2.685907,1.51082 4.238697,4.19673 1.594758,2.64394 1.594758,6.08526 0,3.44131 -1.594758,6.12722 -1.594757,2.68591 -4.280664,4.2387 -2.685908,1.51082 -5.959357,1.51082 z m 0,-4.40656 q 2.01443,0 3.567221,-0.96525 1.55279,-0.96525 2.434103,-2.64394 0.923281,-1.72066 0.923281,-3.86099 0,-2.14034 -0.923281,-3.81903 -0.881313,-1.67869 -2.434103,-2.64394 -1.552791,-0.96525 -3.567221,-0.96525 -1.972463,0 -3.56722,0.96525 -1.552791,0.96525 -2.476071,2.64394 -0.881314,1.67869 -0.881314,3.81903 0,2.14033 0.881314,3.86099 0.92328,1.67869 2.476071,2.64394 1.594757,0.96525 3.56722,0.96525 z m 27.236775,4.40656 q -3.399351,0 -6.043291,-1.55279 -2.601973,-1.59476 -4.154763,-4.28066 -1.510823,-2.68591 -1.510823,-6.08526 0,-3.35739 1.510823,-6.04329 1.510823,-2.68591 4.154763,-4.19673 2.64394,-1.55279 6.043291,-1.55279 2.308201,0 4.322632,0.83934 2.01443,0.79738 3.483286,2.22427 1.510823,1.42689 2.224267,3.31542 l -4.154763,1.93049 q -0.713444,-1.76262 -2.308202,-2.81181 -1.55279,-1.09115 -3.56722,-1.09115 -1.930496,0 -3.483286,0.96525 -1.510823,0.92328 -2.392136,2.64394 -0.881314,1.67869 -0.881314,3.81903 0,2.14033 0.881314,3.86099 0.881313,1.67869 2.392136,2.64394 1.55279,0.96525 3.483286,0.96525 2.056398,0 3.56722,-1.04919 1.552791,-1.09115 2.308202,-2.89574 l 4.154763,1.97246 q -0.671477,1.8046 -2.1823,3.27345 -1.468855,1.42689 -3.483286,2.26624 -2.01443,0.83934 -4.364599,0.83934 z m 14.646624,-0.50361 v -31.76924 h 4.742304 v 13.51347 l -0.75541,-0.58754 q 0.83935,-2.14034 2.68591,-3.23149 1.84656,-1.13311 4.28066,-1.13311 2.51804,0 4.44854,1.09115 1.93049,1.09115 3.02164,3.02164 1.09115,1.9305 1.09115,4.40657 v 14.68855 h -4.70034 V 188.7828 q 0,-1.72066 -0.67147,-2.89575 -0.62951,-1.21705 -1.76263,-1.84656 -1.13312,-0.67148 -2.60197,-0.67148 -1.42689,0 -2.60198,0.67148 -1.13311,0.62951 -1.80459,1.84656 -0.62951,1.21705 -0.62951,2.89575 v 13.38756 z m 31.391564,0.50361 q -2.30821,0 -4.07083,-0.79738 -1.72066,-0.83934 -2.68591,-2.26623 -0.96525,-1.46886 -0.96525,-3.44132 0,-1.84656 0.79738,-3.31542 0.83935,-1.46885 2.56001,-2.47607 1.72066,-1.00721 4.32263,-1.42689 l 7.88985,-1.30098 v 3.73509 l -6.96657,1.21705 q -1.88853,0.33574 -2.76984,1.21705 -0.88132,0.83935 -0.88132,2.1823 0,1.30099 0.96525,2.14033 1.00722,0.79738 2.56001,0.79738 1.93049,0 3.35738,-0.83934 1.46886,-0.83935 2.26624,-2.22427 0.79737,-1.42689 0.79737,-3.14755 v -5.83345 q 0,-1.67869 -1.25901,-2.72788 -1.21706,-1.09115 -3.27345,-1.09115 -1.88853,0 -3.31542,1.00722 -1.38492,0.96525 -2.0564,2.51804 l -3.94492,-1.97247 q 0.6295,-1.67869 2.05639,-2.93771 1.42689,-1.30098 3.31542,-2.01443 1.9305,-0.71344 4.07083,-0.71344 2.68591,0 4.7423,1.00721 2.09837,1.00722 3.23149,2.81181 1.17508,1.76263 1.17508,4.1128 v 15.27609 h -4.53247 v -4.11279 l 0.96525,0.1259 q -0.79738,1.38492 -2.0564,2.39214 -1.21705,1.00721 -2.81181,1.55279 -1.55279,0.54557 -3.48328,0.54557 z m 17.45836,-0.50361 v -31.76924 h 4.7423 v 31.76924 z m 12.29646,9.27478 q -0.83935,0 -1.6787,-0.16787 -0.79737,-0.1259 -1.46885,-0.41967 v -4.07083 q 0.50361,0.1259 1.21705,0.20983 0.71345,0.12591 1.38492,0.12591 1.97247,0 2.89575,-0.88132 0.96524,-0.83934 1.72066,-2.64394 l 1.51082,-3.44132 -0.0839,3.9869 -9.82035,-24.67678 h 5.07805 l 7.26034,18.92726 h -1.72066 l 7.21838,-18.92726 h 5.12001 l -10.03019,25.18039 q -0.79738,2.01443 -2.01443,3.52525 -1.17508,1.55279 -2.81181,2.39213 -1.59475,0.88132 -3.77705,0.88132 z m 30.25842,-8.77117 q -3.48328,0 -6.12722,-1.72066 -2.60198,-1.72066 -3.65116,-4.6164 l 3.65116,-1.72066 q 0.92328,1.9305 2.51803,3.06361 1.63673,1.13312 3.60919,1.13312 1.67869,0 2.72788,-0.75541 1.04918,-0.75541 1.04918,-2.0564 0,-0.83934 -0.46164,-1.34295 -0.46164,-0.54558 -1.17509,-0.88132 -0.67147,-0.33573 -1.38492,-0.5036 l -3.56722,-1.00722 q -2.93771,-0.83934 -4.40656,-2.51804 -1.42689,-1.72066 -1.42689,-3.98689 0,-2.0564 1.04918,-3.56722 1.04918,-1.55279 2.89574,-2.39214 1.84656,-0.83934 4.15477,-0.83934 3.10558,0 5.53968,1.55279 2.4341,1.51082 3.44132,4.23869 l -3.65116,1.72066 q -0.67147,-1.63672 -2.14033,-2.60197 -1.42689,-0.96525 -3.23148,-0.96525 -1.55279,0 -2.47607,0.75541 -0.92328,0.71345 -0.92328,1.88853 0,0.79738 0.41967,1.34296 0.41967,0.5036 1.09115,0.83934 0.67148,0.29377 1.38492,0.50361 l 3.69312,1.09115 q 2.81181,0.79738 4.32264,2.51804 1.51082,1.67869 1.51082,4.02886 0,2.01443 -1.09115,3.56722 -1.04918,1.51082 -2.93771,2.39214 -1.88853,0.83934 -4.40657,0.83934 z m 22.91412,-0.2518 q -3.56722,0 -5.53969,-2.01443 -1.97246,-2.01443 -1.97246,-5.66559 v -10.99543 h -3.98689 v -4.28067 h 0.62951 q 1.59475,0 2.47607,-0.92328 0.88131,-0.92328 0.88131,-2.51803 v -1.76263 h 4.74231 v 5.20394 h 5.16197 v 4.28067 h -5.16197 v 10.7856 q 0,1.17508 0.3777,2.01443 0.37771,0.79737 1.21705,1.25902 0.83935,0.41967 2.1823,0.41967 0.33574,0 0.75541,-0.042 0.41968,-0.042 0.79738,-0.0839 v 4.07082 q -0.58754,0.0839 -1.30098,0.16787 -0.71345,0.0839 -1.25902,0.0839 z"
id="text14"
- aria-label="Epochalyst" />
+ aria-label="EpochLib" />
diff --git a/docs/_static/logo.css b/docs/_static/logo.css
index a267b32..c8e0f3b 100644
--- a/docs/_static/logo.css
+++ b/docs/_static/logo.css
@@ -1,11 +1,11 @@
-.epochalyst-logo {
- content: url('https://raw.githubusercontent.com/TeamEpochGithub/epochalyst/main/assets/Epochalyst_Logo_Dark.svg');
+.epochlib-logo {
+ content: url('https://raw.githubusercontent.com/TeamEpochGithub/epochlib/main/assets/EpochLib.svg');
max-width: 100%;
height: auto;
background-size: contain;
background-repeat: no-repeat;
}
-html.dark .epochalyst-logo {
- content: url('https://raw.githubusercontent.com/TeamEpochGithub/epochalyst/main/assets/Epochalyst_Logo_Light.svg');
+html.dark .epochlib-logo {
+ content: url('https://raw.githubusercontent.com/TeamEpochGithub/epochlib/main/assets/EpochLib.svg');
}
diff --git a/docs/conf.py b/docs/conf.py
index 05594f5..25af310 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -9,11 +9,11 @@
from pathlib import Path
from typing import Any, Final
-REPO_URL: Final[str] = "https://github.com/TeamEpochGithub/epochalyst"
+REPO_URL: Final[str] = "https://github.com/TeamEpochGithub/epochlib"
sys.path.insert(0, Path("../..").resolve().as_posix())
-project: Final[str] = "Epochalyst"
+project: Final[str] = "EpochLib"
copyright: Final[str] = "2024, Team Epoch." # noqa: A001
author: Final[str] = "Team Epoch"
diff --git a/docs/index.rst b/docs/index.rst
index 36484c2..07c8421 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,6 +1,6 @@
.. raw:: html
-
+
-------------------
@@ -22,7 +22,7 @@ API
:toctree: _autosummary
:recursive:
- epochalyst.epochalyst
+ epochlib.epochlib
Indices and tables
==================
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 0b6c140..b84252b 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,7 +1,7 @@
-sphinx==7.2.6
-sphinx-autodoc-typehints==2.0.0
-sphinxawesome-theme==5.1.3
-myst-parser==2.0.0
-pygit2==1.14.1
+sphinx==8.1.3
+sphinx-autodoc-typehints==2.5.0
+sphinxawesome-theme==5.3.2
+myst-parser==4.0.0
+pygit2==1.16.0
agogos==0.4.0
-torch==2.2.2
+torch==2.5.1
diff --git a/epochalyst/caching/cacher.py b/epochalyst/caching/cacher.py
deleted file mode 100644
index 6a4dcc3..0000000
--- a/epochalyst/caching/cacher.py
+++ /dev/null
@@ -1,356 +0,0 @@
-"""The cacher module contains the Cacher class."""
-
-import glob
-import os
-import pickle
-import sys
-from typing import Any, Literal, TypedDict
-
-from epochalyst.logging import Logger
-
-try:
- import dask.array as da
- import dask.dataframe as dd
-except ImportError:
- """User doesn't require these packages"""
-
-try:
- import numpy as np
-except ImportError:
- """User doesn't require these packages"""
-
-try:
- import pandas as pd
-except ImportError:
- """User doesn't require these packages"""
-
-try:
- import polars as pl
-except ImportError:
- """User doesn't require these packages"""
-
-if sys.version_info < (3, 11): # pragma: no cover ( bool: # Check if the cache exists
-
- def _get_cache(name: str, cache_args: _CacheArgs | None = None) -> Any: # Load the cache
-
- def _store_cache(name: str, data: Any, cache_args: _CacheArgs | None = None) -> None: # Store data
- """
-
- def cache_exists(self, name: str, cache_args: CacheArgs | None = None) -> bool:
- """Check if the cache exists.
-
- :param cache_args: The cache arguments.
- :return: True if the cache exists, False otherwise.
- """
- if not cache_args:
- return False
-
- # Check if cache_args contains storage type and storage path
- if "storage_type" not in cache_args or "storage_path" not in cache_args:
- raise ValueError("cache_args must contain storage_type and storage_path")
-
- storage_type = cache_args["storage_type"]
- storage_path = cache_args["storage_path"]
-
- self.log_to_debug(
- f"Checking if cache exists for type: {storage_type} and path: {storage_path}",
- )
-
- # If storage path does not end a slash, add it
- if storage_path[-1] != "/":
- storage_path += "/"
-
- # Check if path exists
- path_exists = False
-
- if storage_type == ".npy":
- path_exists = os.path.exists(storage_path + name + ".npy")
- elif storage_type == ".parquet":
- path_exists = os.path.exists(storage_path + name + ".parquet")
- elif storage_type == ".csv":
- # Check if the file exists or if there are any parts inside the folder
- path_exists = os.path.exists(storage_path + name + ".csv") or glob.glob(storage_path + name + "/*.part") != []
- elif storage_type == ".npy_stack":
- path_exists = os.path.exists(storage_path + name)
- elif storage_type == ".pkl":
- path_exists = os.path.exists(storage_path + name + ".pkl")
-
- self.log_to_debug(
- f"Cache exists is {path_exists} for type: {storage_type} and path: {storage_path}",
- )
-
- return path_exists
-
- def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: # noqa: ANN401 C901 PLR0911 PLR0912
- """Load the cache.
-
- :param name: The name of the cache.
- :param cache_args: The cache arguments.
- :return: The cached data.
- """
- # Check if cache_args is empty
- if not cache_args:
- raise ValueError("cache_args is empty")
-
- # Check if storage type, storage_path and output_data_type are in cache_args
- if "storage_type" not in cache_args or "storage_path" not in cache_args or "output_data_type" not in cache_args:
- raise ValueError(
- "cache_args must contain storage_type, storage_path and output_data_type",
- )
-
- storage_type = cache_args["storage_type"]
- storage_path = cache_args["storage_path"]
- output_data_type = cache_args["output_data_type"]
- read_args = cache_args.get("read_args", {})
-
- # If storage path does not end a slash, add it
- if storage_path[-1] != "/":
- storage_path += "/"
-
- # Load the cache
- if storage_type == ".npy":
- # Check if output_data_type is supported and load cache to output_data_type
- self.log_to_debug(f"Loading .npy file from {storage_path + name}")
- if output_data_type == "numpy_array":
- return np.load(storage_path + name + ".npy", **read_args)
- if output_data_type == "dask_array":
- return da.from_array(np.load(storage_path + name + ".npy"), **read_args)
-
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for loading .npy file.",
- )
- raise ValueError(
- "output_data_type must be numpy_array or dask_array, other types not supported yet",
- )
- if storage_type == ".parquet":
- # Check if output_data_type is supported and load cache to output_data_type
- self.log_to_debug(f"Loading .parquet file from {storage_path + name}")
- if output_data_type == "pandas_dataframe":
- return pd.read_parquet(storage_path + name + ".parquet", **read_args)
- if output_data_type == "dask_dataframe":
- return dd.read_parquet(storage_path + name + ".parquet", **read_args)
- if output_data_type == "numpy_array":
- return pd.read_parquet(
- storage_path + name + ".parquet",
- **read_args,
- ).to_numpy()
- if output_data_type == "dask_array":
- return dd.read_parquet(
- storage_path + name + ".parquet",
- **read_args,
- ).to_dask_array()
- if output_data_type == "polars_dataframe":
- return pl.read_parquet(storage_path + name + ".parquet", **read_args)
-
- self.log_to_debug( # type: ignore[unreachable]
- f"Invalid output data type: {output_data_type}, for loading .parquet file.",
- )
- raise ValueError(
- "output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
- )
- if storage_type == ".csv":
- # Check if output_data_type is supported and load cache to output_data_type
- self.log_to_debug(f"Loading .csv file from {storage_path + name}")
- if output_data_type == "pandas_dataframe":
- return pd.read_csv(storage_path + name + ".csv", **read_args)
- if output_data_type == "dask_dataframe":
- return dd.read_csv(storage_path + name + "/*.part", **read_args)
- if output_data_type == "polars_dataframe":
- return pl.read_csv(storage_path + name + ".csv", **read_args)
-
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for loading .csv file.",
- )
- raise ValueError(
- "output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe, other types not supported yet",
- )
- if storage_type == ".npy_stack":
- # Check if output_data_type is supported and load cache to output_data_type
- self.log_to_debug(f"Loading .npy_stack file from {storage_path + name}")
- if output_data_type == "dask_array":
- return da.from_npy_stack(storage_path + name, **read_args)
-
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for loading .npy_stack file.",
- )
- raise ValueError(
- "output_data_type must be dask_array, other types not supported yet",
- )
- if storage_type == ".pkl":
- # Load the pickle file
- self.log_to_debug(
- f"Loading pickle file from {storage_path + name + '.pkl'}",
- )
- with open(storage_path + name + ".pkl", "rb") as file:
- return pickle.load(file, **read_args) # noqa: S301
-
- self.log_to_debug(f"Invalid storage type: {storage_type}") # type: ignore[unreachable]
- raise ValueError(
- "storage_type must be .npy, .parquet, .csv, or .npy_stack, other types not supported yet",
- )
-
- def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None) -> None: # noqa: ANN401 C901 PLR0915 PLR0912
- """Store one set of data.
-
- :param name: The name of the cache.
- :param data: The data to store.
- :param cache_args: The cache arguments.
- """
- # Check if cache_args is empty
- if not cache_args:
- raise ValueError("cache_args is empty")
-
- # Check if storage type, storage_path and output_data_type are in cache_args
- if "storage_type" not in cache_args or "storage_path" not in cache_args or "output_data_type" not in cache_args:
- raise ValueError(
- "cache_args must contain storage_type, storage_path and output_data_type",
- )
-
- storage_type = cache_args["storage_type"]
- storage_path = cache_args["storage_path"]
- output_data_type = cache_args["output_data_type"]
- store_args = cache_args.get("store_args", {})
-
- # If storage path does not end a slash, add it
- if storage_path[-1] != "/":
- storage_path += "/"
-
- # Store the cache
- if storage_type == ".npy":
- # Check if output_data_type is supported and store cache to output_data_type
- self.log_to_debug(f"Storing .npy file to {storage_path + name}")
- if output_data_type == "numpy_array":
- np.save(storage_path + name + ".npy", data, **store_args)
- elif output_data_type == "dask_array":
- np.save(storage_path + name + ".npy", data.compute(), **store_args)
- else:
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for storing .npy file.",
- )
- raise ValueError(
- "output_data_type must be numpy_array or dask_array, other types not supported yet",
- )
- elif storage_type == ".parquet":
- # Check if output_data_type is supported and store cache to output_data_type
- self.log_to_debug(f"Storing .parquet file to {storage_path + name}")
- if output_data_type in {"pandas_dataframe", "dask_dataframe"}:
- data.to_parquet(storage_path + name + ".parquet", **store_args)
- elif output_data_type == "numpy_array":
- pd.DataFrame(data).to_parquet(
- storage_path + name + ".parquet",
- **store_args,
- )
- elif output_data_type == "dask_array":
- new_dd = dd.from_dask_array(data)
- new_dd = new_dd.rename(
- columns={col: str(col) for col in new_dd.columns},
- )
- new_dd.to_parquet(storage_path + name + ".parquet", **store_args)
- elif output_data_type == "polars_dataframe":
- data.write_parquet(storage_path + name + ".parquet", **store_args)
- else:
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for storing .parquet file.",
- )
- raise ValueError(
- "output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
- )
- elif storage_type == ".csv":
- # Check if output_data_type is supported and store cache to output_data_type
- self.log_to_debug(f"Storing .csv file to {storage_path + name}")
- if output_data_type == "pandas_dataframe":
- data.to_csv(storage_path + name + ".csv", **({"index": False} | store_args))
- elif output_data_type == "dask_dataframe":
- data.to_csv(storage_path + name, **({"index": False} | store_args))
- elif output_data_type == "polars_dataframe":
- data.write_csv(storage_path + name + ".csv", **store_args)
- else:
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for storing .csv file.",
- )
- raise ValueError(
- "output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe, other types not supported yet",
- )
- elif storage_type == ".npy_stack":
- # Check if output_data_type is supported and store cache to output_data_type
- self.log_to_debug(f"Storing .npy_stack file to {storage_path + name}")
- if output_data_type == "dask_array":
- da.to_npy_stack(storage_path + name, data, **store_args)
- else:
- self.log_to_debug(
- f"Invalid output data type: {output_data_type}, for storing .npy_stack file.",
- )
- raise ValueError(
- "output_data_type must be numpy_array other types not supported yet",
- )
- elif storage_type == ".pkl":
- # Store the pickle file
- self.log_to_debug(f"Storing pickle file to {storage_path + name + '.pkl'}")
- with open(storage_path + name + ".pkl", "wb") as f:
- pickle.dump(
- data,
- f,
- **({"protocol": pickle.HIGHEST_PROTOCOL} | store_args),
- )
- else:
- self.log_to_debug(f"Invalid storage type: {storage_type}") # type: ignore[unreachable]
- raise ValueError(
- "storage_type must be .npy, .parquet, .csv or .npy_stack, other types not supported yet",
- )
diff --git a/epochalyst/__init__.py b/epochlib/__init__.py
similarity index 81%
rename from epochalyst/__init__.py
rename to epochlib/__init__.py
index 64eaef1..6119a4a 100644
--- a/epochalyst/__init__.py
+++ b/epochlib/__init__.py
@@ -1,4 +1,4 @@
-"""The epochalyst package."""
+"""The epochlib package."""
from .ensemble import EnsemblePipeline
from .model import ModelPipeline
diff --git a/epochalyst/caching/__init__.py b/epochlib/caching/__init__.py
similarity index 66%
rename from epochalyst/caching/__init__.py
rename to epochlib/caching/__init__.py
index c4f989f..a22b347 100644
--- a/epochalyst/caching/__init__.py
+++ b/epochlib/caching/__init__.py
@@ -1,4 +1,4 @@
-"""Caching module for epochalyst."""
+"""Caching module for epochlib."""
from .cacher import CacheArgs, Cacher
diff --git a/epochlib/caching/cacher.py b/epochlib/caching/cacher.py
new file mode 100644
index 0000000..b10946b
--- /dev/null
+++ b/epochlib/caching/cacher.py
@@ -0,0 +1,363 @@
+"""The cacher module contains the Cacher class."""
+
+import glob
+import os
+import pickle
+import sys
+from pathlib import Path
+from typing import Any, Callable, Dict, Literal, TypedDict
+
+import numpy as np
+
+from epochlib.logging import Logger
+
+try:
+ import dask.array as da
+ import dask.dataframe as dd
+except ImportError:
+ """User doesn't require these packages"""
+
+
+try:
+ import pandas as pd
+except ImportError:
+ """User doesn't require these packages"""
+
+try:
+ import polars as pl
+except ImportError:
+ """User doesn't require these packages"""
+
+if sys.version_info < (3, 11): # pragma: no cover ( bool: # Check if the cache exists
+
+ def _get_cache(name: str, cache_args: _CacheArgs | None = None) -> Any: # Load the cache
+
+ def _store_cache(name: str, data: Any, cache_args: _CacheArgs | None = None) -> None: # Store data
+ """
+
+ def cache_exists(self, name: str, cache_args: CacheArgs | None = None) -> bool:
+ """Check if the cache exists.
+
+ :param cache_args: The cache arguments.
+ :return: True if the cache exists, False otherwise.
+ """
+ if not cache_args:
+ return False
+
+ # Check if cache_args contains storage type and storage path
+ if "storage_type" not in cache_args or "storage_path" not in cache_args:
+ raise ValueError("cache_args must contain storage_type and storage_path")
+
+ storage_type = cache_args["storage_type"]
+ storage_path = cache_args["storage_path"]
+
+ self.log_to_debug(
+ f"Checking if cache exists for type: {storage_type} and path: {storage_path}",
+ )
+
+ # If storage path does not end a slash, add it
+ if storage_path[-1] != "/":
+ storage_path += "/"
+
+ # Check if path exists
+ path_exists = False
+
+ if storage_type == ".npy":
+ path_exists = os.path.exists(storage_path + name + ".npy")
+ elif storage_type == ".parquet":
+ path_exists = os.path.exists(storage_path + name + ".parquet")
+ elif storage_type == ".csv":
+ # Check if the file exists or if there are any parts inside the folder
+ path_exists = os.path.exists(storage_path + name + ".csv") or glob.glob(storage_path + name + "/*.part") != []
+ elif storage_type == ".npy_stack":
+ path_exists = os.path.exists(storage_path + name)
+ elif storage_type == ".pkl":
+ path_exists = os.path.exists(storage_path + name + ".pkl")
+
+ self.log_to_debug(
+ f"Cache exists is {path_exists} for type: {storage_type} and path: {storage_path}",
+ )
+
+ return path_exists
+
+ def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any:
+ """Load the cache.
+
+ :param name: The name of the cache.
+ :param cache_args: The cache arguments.
+ :return: The cached data.
+ """
+ # Check if cache_args is empty
+ if not cache_args:
+ raise ValueError("cache_args is empty")
+
+ # Check if storage type, storage_path and output_data_type are in cache_args
+ required_keys = ["storage_type", "storage_path", "output_data_type"]
+ for key in required_keys:
+ if key not in cache_args:
+ raise ValueError(f"cache_args must contain {', '.join(required_keys)}")
+
+ if "storage_type" not in cache_args or "storage_path" not in cache_args or "output_data_type" not in cache_args:
+ raise ValueError(
+ "cache_args must contain storage_type, storage_path and output_data_type",
+ )
+
+ storage_type = cache_args["storage_type"]
+ storage_path = Path(cache_args["storage_path"])
+ output_data_type = cache_args["output_data_type"]
+ read_args = cache_args.get("read_args", {})
+
+ load_functions: Dict[str, LoaderFunction] = {
+ ".npy": self._load_npy,
+ ".parquet": self._load_parquet,
+ ".csv": self._load_csv,
+ ".npy_stack": self._load_npy_stack,
+ ".pkl": self._load_pkl,
+ }
+
+ if storage_type in load_functions:
+ return load_functions[storage_type](name, storage_path, output_data_type, read_args)
+
+ self.log_to_debug(f"Invalid storage type: {storage_type}")
+ raise ValueError(
+ "storage_type must be .npy, .parquet, .csv, or .npy_stack, other types not supported yet",
+ )
+
+ def _load_npy(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
+ # Check if output_data_type is supported and load cache to output_data_type
+ self.log_to_debug(f"Loading .npy file from {storage_path / name}")
+ if output_data_type == "numpy_array":
+ return np.load(storage_path / f"{name}.npy", **read_args)
+ if output_data_type == "dask_array":
+ return da.from_array(np.load(storage_path / f"{name}.npy"), **read_args)
+
+ self.log_to_debug(
+ f"Invalid output data type: {output_data_type}, for loading .npy file.",
+ )
+ raise ValueError(
+ "output_data_type must be numpy_array or dask_array, other types not supported yet",
+ )
+
+ def _load_parquet(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
+ # Check if output_data_type is supported and load cache to output_data_type
+ self.log_to_debug(f"Loading .parquet file from {storage_path}/{name}")
+ if output_data_type == "pandas_dataframe":
+ return pd.read_parquet(storage_path / f"{name}.parquet", **read_args)
+ if output_data_type == "dask_dataframe":
+ return dd.read_parquet(storage_path / f"{name}.parquet", **read_args)
+ if output_data_type == "numpy_array":
+ return pd.read_parquet(
+ storage_path / f"{name}.parquet",
+ **read_args,
+ ).to_numpy()
+ if output_data_type == "dask_array":
+ return dd.read_parquet(
+ storage_path / f"{name}.parquet",
+ **read_args,
+ ).to_dask_array()
+ if output_data_type == "polars_dataframe":
+ return pl.read_parquet(storage_path / f"{name}.parquet", **read_args)
+
+ self.log_to_debug(
+ f"Invalid output data type: {output_data_type}, for loading .parquet file.",
+ )
+ raise ValueError(
+ "output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
+ )
+
+ def _load_csv(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
+ # Check if output_data_type is supported and load cache to output_data_type
+ self.log_to_debug(f"Loading .csv file from {storage_path / name}")
+ if output_data_type == "pandas_dataframe":
+ return pd.read_csv(storage_path / f"{name}.csv", **read_args)
+ if output_data_type == "dask_dataframe":
+ return dd.read_csv(storage_path / f"{name}/*.part", **read_args)
+ if output_data_type == "polars_dataframe":
+ return pl.read_csv(storage_path / f"{name}.csv", **read_args)
+
+ self.log_to_debug(
+ f"Invalid output data type: {output_data_type}, for loading .csv file.",
+ )
+ raise ValueError(
+ "output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe, other types not supported yet",
+ )
+
+ def _load_npy_stack(self, name: str, storage_path: Path, output_data_type: str, read_args: Any) -> Any:
+ # Check if output_data_type is supported and load cache to output_data_type
+ self.log_to_debug(f"Loading .npy_stack file from {storage_path / name}")
+ if output_data_type == "dask_array":
+ return da.from_npy_stack(storage_path / name, **read_args)
+
+ self.log_to_debug(
+ f"Invalid output data type: {output_data_type}, for loading .npy_stack file.",
+ )
+ raise ValueError(
+ "output_data_type must be dask_array, other types not supported yet",
+ )
+
+ def _load_pkl(self, name: str, storage_path: Path, _output_data_type: str, read_args: Any) -> Any:
+ # Load the pickle file
+ self.log_to_debug(
+ f"Loading pickle file from {storage_path}/{name}.pkl",
+ )
+ with open(storage_path / f"{name}.pkl", "rb") as file:
+ return pickle.load(file, **read_args) # noqa: S301
+
+ def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None) -> None:
+ """Store one set of data.
+
+ :param name: The name of the cache.
+ :param data: The data to store.
+ :param cache_args: The cache arguments.
+ """
+ if not cache_args:
+ raise ValueError("cache_args is empty")
+
+ required_keys = ["storage_type", "storage_path", "output_data_type"]
+ for key in required_keys:
+ if key not in cache_args:
+ raise ValueError(f"cache_args must contain {', '.join(required_keys)}")
+
+ storage_type = cache_args["storage_type"]
+ storage_path = Path(cache_args["storage_path"])
+ output_data_type = cache_args["output_data_type"]
+ store_args = cache_args.get("store_args", {})
+
+ store_functions: Dict[str, StoreFunction] = {
+ ".npy": self._store_npy,
+ ".parquet": self._store_parquet,
+ ".csv": self._store_csv,
+ ".npy_stack": self._store_npy_stack,
+ ".pkl": self._store_pkl,
+ }
+
+ if storage_type in store_functions:
+ return store_functions[storage_type](name, storage_path, data, output_data_type, store_args)
+
+ self.log_to_debug(f"Invalid storage type: {storage_type}")
+ raise ValueError(f"storage_type is {storage_type} must be .npy, .parquet, .csv, .npy_stack, or .pkl, other types not supported yet")
+
+ def _store_npy(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
+ file_path = storage_path / f"{name}.npy"
+ self.log_to_debug(f"Storing .npy file to {file_path}")
+ if output_data_type == "numpy_array":
+ np.save(file_path, data, **store_args)
+ elif output_data_type == "dask_array":
+ np.save(file_path, data.compute(), **store_args)
+ else:
+ raise ValueError("output_data_type must be numpy_array or dask_array")
+
+ def _store_parquet(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
+ # Check if output_data_type is supported and store cache to output_data_type
+ self.log_to_debug(f"Storing .parquet file to {storage_path / name}")
+ if output_data_type in {"pandas_dataframe", "dask_dataframe"}:
+ data.to_parquet(storage_path / f"{name}.parquet", **store_args)
+ elif output_data_type == "numpy_array":
+ pd.DataFrame(data).to_parquet(
+ storage_path / f"{name}.parquet",
+ **store_args,
+ )
+ elif output_data_type == "dask_array":
+ new_dd = dd.from_dask_array(data)
+ new_dd = new_dd.rename(
+ columns={col: str(col) for col in new_dd.columns},
+ )
+ new_dd.to_parquet(storage_path / f"{name}.parquet", **store_args)
+ elif output_data_type == "polars_dataframe":
+ data.write_parquet(storage_path / f"{name}.parquet", **store_args)
+ else:
+ self.log_to_debug(
+ f"Invalid output data type: {output_data_type}, for storing .parquet file.",
+ )
+ raise ValueError(
+ "output_data_type must be pandas_dataframe, dask_dataframe, numpy_array, dask_array, or polars_dataframe, other types not supported yet",
+ )
+
+ def _store_csv(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
+ if output_data_type == "pandas_dataframe":
+ data.to_csv(storage_path / f"{name}.csv", index=False, **store_args)
+ self.log_to_debug(f"Storing .csv file to {storage_path}/{name}.csv")
+ elif output_data_type == "dask_dataframe":
+ data.to_csv(storage_path / name, index=False, **store_args)
+ self.log_to_debug(f"Storing .csv file to {storage_path}/{name}")
+ elif output_data_type == "polars_dataframe":
+ data.write_csv(storage_path / f"{name}.csv", **store_args)
+ self.log_to_debug(f"Storing .csv file to {storage_path}/{name}.csv")
+ else:
+ raise ValueError("output_data_type must be pandas_dataframe, dask_dataframe, or polars_dataframe")
+
+ def _store_npy_stack(self, name: str, storage_path: Path, data: Any, output_data_type: str, store_args: Any) -> None:
+ # Handling npy_stack case differently as it might need a different path structure
+ storage_path /= name # Treat name as a directory here
+ self.log_to_debug(f"Storing .npy_stack file to {storage_path}")
+ if output_data_type == "dask_array":
+ da.to_npy_stack(storage_path, data, **store_args)
+ else:
+ raise ValueError("output_data_type must be dask_array")
+
+ def _store_pkl(self, name: str, storage_path: Path, data: Any, _output_data_type: str, store_args: Any) -> None:
+ file_path = storage_path / f"{name}.pkl"
+ self.log_to_debug(f"Storing pickle file to {file_path}")
+ with open(file_path, "wb") as f:
+ pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL, **store_args)
diff --git a/epochlib/data/__init__.py b/epochlib/data/__init__.py
new file mode 100644
index 0000000..e00a0ad
--- /dev/null
+++ b/epochlib/data/__init__.py
@@ -0,0 +1,10 @@
+"""Module containing data related classes and functions."""
+
+from .enum_data_format import Data, DataRetrieval
+from .pipeline_dataset import PipelineDataset
+
+__all__ = [
+ "Data",
+ "DataRetrieval",
+ "PipelineDataset",
+]
diff --git a/epochlib/data/enum_data_format.py b/epochlib/data/enum_data_format.py
new file mode 100644
index 0000000..3c66bca
--- /dev/null
+++ b/epochlib/data/enum_data_format.py
@@ -0,0 +1,45 @@
+"""Module containing classes to allow for the creation of enum based retrieval data formats."""
+
+from dataclasses import dataclass, field
+from enum import IntFlag
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+
+
+class DataRetrieval(IntFlag):
+ """Class to select which data to retrieve in Data."""
+
+
+@dataclass
+class Data:
+ """Class to describe a data format.
+
+ :param retrieval: What data to retrieve
+ """
+
+ retrieval: DataRetrieval = field(init=False)
+
+ def __getitem__(self, idx: int | npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any] | list[Any]:
+ """Get item from the data.
+
+ :param idx: Index to retrieve
+ :return: Relevant data
+ """
+ raise NotImplementedError("__getitem__ should be implemented when inheriting from Data.")
+
+ def __getitems__(self, indices: npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any]:
+ """Retrieve items for all indices based on specified retrieval flags.
+
+ :param indices: List of indices to retrieve
+ :return: Relevant data
+ """
+ raise NotImplementedError("__getitems__ should be implemented when inheriting from Data.")
+
+ def __len__(self) -> int:
+ """Return length of the data.
+
+ :return: Length of data
+ """
+ raise NotImplementedError("__len__ should be implemented when inheriting from Data.")
diff --git a/epochlib/data/pipeline_dataset.py b/epochlib/data/pipeline_dataset.py
new file mode 100644
index 0000000..61636d4
--- /dev/null
+++ b/epochlib/data/pipeline_dataset.py
@@ -0,0 +1,137 @@
+"""Module that contains a dataset that can take a training pipeline."""
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Callable, Sequence, Tuple, TypeVar
+
+import numpy as np
+import numpy.typing as npt
+
+from epochlib.data.enum_data_format import Data, DataRetrieval
+from epochlib.training.training import TrainingPipeline
+from epochlib.training.training_block import TrainingBlock
+
+try:
+ from torch.utils.data import Dataset
+except ImportError:
+ """User doesn't require torch"""
+
+T = TypeVar("T")
+DataTuple = Tuple[T, T]
+
+
+@dataclass
+class PipelineDataset(Dataset[Tuple[T, T]]):
+ """Pipeline dataset takes in a pipeline to be able to process the original data.
+
+ Useful for lazy loading data for training, where computing would take a long time or the device would run out of memory.
+
+ :param retrieval: Data retrieval object.
+ :param retrieval_type: Data retrieval enum
+ :param steps: Steps to apply to the dataset
+ :param result_formatter:
+
+ :param x: Input data
+ :param y: Labels
+ :param indices: Indices to use
+
+ """
+
+ x: Data | None = None
+ y: Data | None = None
+ indices: npt.NDArray[np.int_] | None = None
+
+ retrieval: list[str] | None = None
+ retrieval_type: DataRetrieval | None = None
+ steps: Sequence[TrainingBlock] | None = None
+ result_formatter: Callable[[Any], Any] = lambda a: a
+
+ def __post_init__(self) -> None:
+ """Set up the dataset."""
+ if self.retrieval is None:
+ raise ValueError("Retrieval object must be set.")
+ if self.retrieval_type is None:
+ raise ValueError("Retrieval type must be set.")
+
+ # Setup data retrieval
+ self._retrieval_enum = getattr(self.retrieval_type, self.retrieval[0])
+ for retrieval in self.retrieval[1:]:
+ self._retrieval_enum = self._retrieval_enum | getattr(self.retrieval_type, retrieval)
+
+ # Setup pipeline
+ self.setup_pipeline(use_augmentations=False)
+
+ def initialize(self, x: Data, y: Data, indices: list[int] | npt.NDArray[np.int_] | None = None) -> None:
+ """Set up the dataset for training.
+
+ :param x: X data to initialize with
+ :param y: Y data to initialize with
+ :param indices: Indices to filter on
+ """
+ self.x = x
+ self.y = y
+ self.indices = np.array(indices, dtype=np.int32) if isinstance(indices, list) else indices
+
+ def setup_pipeline(self, *, use_augmentations: bool) -> None:
+ """Set whether to use the augmentations.
+
+ :param use_augmentations: Whether to use augmentations while passing data through pipeline
+ """
+ self._enabled_steps: Sequence[TrainingBlock] = []
+
+ if self.steps is not None:
+ for step in self.steps:
+ if not hasattr(step, "is_augmentation"):
+ continue
+ if (step.is_augmentation and use_augmentations) or not step.is_augmentation:
+ self._enabled_steps.append(step)
+
+ self._pipeline = TrainingPipeline(steps=self._enabled_steps)
+ logging.getLogger("TrainingPipeline").setLevel(logging.WARNING)
+
+ def __len__(self) -> int:
+ """Get the length of the dataset."""
+ if self.x is None:
+ raise ValueError("Dataset is not initialized.")
+ if self.indices is None:
+ return len(self.x)
+ return len(self.indices)
+
+ def __getitem__(self, idx: int | list[int] | npt.NDArray[np.int_]) -> tuple[Any, Any]:
+ """Get an item from the dataset.
+
+ :param idx: Index to retrieve
+ :return: Data and labels at the Index
+ """
+ if not isinstance(idx, (int | np.integer)):
+ return self.__getitems__(idx)
+
+ if self.x is None:
+ raise ValueError("Dataset not initialized or has no x data.")
+ if self.indices is not None:
+ idx = self.indices[idx]
+
+ self.x.retrieval = self._retrieval_enum
+ x = np.expand_dims(self.x[idx], axis=0)
+ y = np.expand_dims(self.y[idx], axis=0) if self.y is not None else None
+
+ x, y = self._pipeline.train(x, y)
+ return self.result_formatter(x)[0], self.result_formatter(y)[0] if y is not None else None
+
+ def __getitems__(self, indices: list[int] | npt.NDArray[np.int_]) -> tuple[Any, Any]:
+ """Get items from the dataset.
+
+ :param indices: The indices to retrieve
+ :return: Data and labels at the indices.
+ """
+ if self.x is None:
+ raise ValueError("Dataset not initialized or has no x data.")
+ if self.indices is not None:
+ indices = self.indices[indices]
+
+ self.x.retrieval = self._retrieval_enum
+ x = self.x[indices]
+ y = self.y[indices] if self.y is not None else None
+
+ x, y = self._pipeline.train(x, y)
+ return self.result_formatter(x), self.result_formatter(y) if y is not None else None
diff --git a/epochalyst/ensemble.py b/epochlib/ensemble.py
similarity index 75%
rename from epochalyst/ensemble.py
rename to epochlib/ensemble.py
index 400986c..69cb652 100644
--- a/epochalyst/ensemble.py
+++ b/epochlib/ensemble.py
@@ -2,13 +2,13 @@
from typing import Any
-from agogos.training import ParallelTrainingSystem
-
-from epochalyst.caching import CacheArgs
+from epochlib.caching import CacheArgs
+from epochlib.model import ModelPipeline
+from epochlib.pipeline import ParallelTrainingSystem
class EnsemblePipeline(ParallelTrainingSystem):
- """EnsemblePipeline is the class used to create the pipeline for the model. (Currently same implementation as agogos pipeline).
+ """EnsemblePipeline is the class used to create the pipeline for the model.
:param steps: Trainers to ensemble
"""
@@ -22,7 +22,7 @@ def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
if len(self.steps) == 0:
return False
- return all(step.get_x_cache_exists(cache_args) for step in self.steps)
+ return all(isinstance(step, ModelPipeline) and step.get_x_cache_exists(cache_args) for step in self.steps)
def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of y cache.
@@ -33,9 +33,9 @@ def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
if len(self.steps) == 0:
return False
- return all(step.get_y_cache_exists(cache_args) for step in self.steps)
+ return all(isinstance(step, ModelPipeline) and step.get_y_cache_exists(cache_args) for step in self.steps)
- def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401
+ def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
"""Concatenate the trained data.
:param original_data: First input data
diff --git a/epochalyst/logging/__init__.py b/epochlib/logging/__init__.py
similarity index 100%
rename from epochalyst/logging/__init__.py
rename to epochlib/logging/__init__.py
diff --git a/epochalyst/logging/logger.py b/epochlib/logging/logger.py
similarity index 97%
rename from epochalyst/logging/logger.py
rename to epochlib/logging/logger.py
index 8836e70..ceab1ad 100644
--- a/epochalyst/logging/logger.py
+++ b/epochlib/logging/logger.py
@@ -2,9 +2,11 @@
import logging
import os
+from dataclasses import dataclass
from typing import Any, Mapping
+@dataclass
class Logger:
"""Logger base class for logging methods.
@@ -21,7 +23,7 @@ def log_to_external(self, message: dict[str, Any], **kwargs: Any) -> None: # Log
def external_define_metric(self, metric: str, metric_type: str) -> None: # Defines an external metric
"""
- def __init__(self) -> None:
+ def __post_init__(self) -> None:
"""Initialize the logger."""
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(logging.DEBUG)
diff --git a/epochalyst/model.py b/epochlib/model.py
similarity index 85%
rename from epochalyst/model.py
rename to epochlib/model.py
index f91b6c7..9d3f85f 100644
--- a/epochalyst/model.py
+++ b/epochlib/model.py
@@ -2,13 +2,12 @@
from typing import Any
-from agogos.training import Pipeline
-
-from epochalyst.caching import CacheArgs
+from epochlib.caching import CacheArgs, Cacher
+from epochlib.pipeline import Pipeline
class ModelPipeline(Pipeline):
- """ModelPipeline is the class used to create the pipeline for the model. (Currently same implementation as agogos pipeline).
+ """ModelPipeline is the class used to create the pipeline for the model.
:param x_sys: The system to transform the input data.
:param y_sys: The system to transform the label data.
@@ -24,7 +23,7 @@ def __post_init__(self) -> None:
"""
return super().__post_init__()
- def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Train the system.
:param x: The input to the system.
@@ -33,7 +32,7 @@ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa:
"""
return super().train(x, y, **train_args)
- def predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401
+ def predict(self, x: Any, **pred_args: Any) -> Any:
"""Predict the output of the system.
:param x: The input to the system.
@@ -47,7 +46,7 @@ def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
:param cache_args: Cache arguments
:return: Whether cache exists
"""
- if self.x_sys is None:
+ if self.x_sys is None or not isinstance(self.x_sys, Cacher):
return False
return self.x_sys.cache_exists(self.x_sys.get_hash(), cache_args)
@@ -57,7 +56,7 @@ def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
:param cache_args: Cache arguments
:return: Whether cache exists
"""
- if self.y_sys is None:
+ if self.y_sys is None or not isinstance(self.y_sys, Cacher):
return False
return self.y_sys.cache_exists(self.y_sys.get_hash(), cache_args)
diff --git a/epochlib/pipeline/__init__.py b/epochlib/pipeline/__init__.py
new file mode 100644
index 0000000..745d851
--- /dev/null
+++ b/epochlib/pipeline/__init__.py
@@ -0,0 +1,21 @@
+"""Core pipeline functionality for training and transforming data."""
+
+from .core import Base, Block, ParallelSystem, SequentialSystem
+from .training import ParallelTrainingSystem, Pipeline, Trainer, TrainingSystem, TrainType
+from .transforming import ParallelTransformingSystem, Transformer, TransformingSystem, TransformType
+
+__all__ = [
+ "TrainType",
+ "Trainer",
+ "TrainingSystem",
+ "ParallelTrainingSystem",
+ "Pipeline",
+ "TransformType",
+ "Transformer",
+ "TransformingSystem",
+ "ParallelTransformingSystem",
+ "Base",
+ "SequentialSystem",
+ "ParallelSystem",
+ "Block",
+]
diff --git a/epochlib/pipeline/core.py b/epochlib/pipeline/core.py
new file mode 100644
index 0000000..b4e2935
--- /dev/null
+++ b/epochlib/pipeline/core.py
@@ -0,0 +1,285 @@
+"""This module contains the core classes for all classes in the epochlib package."""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Sequence
+
+from joblib import hash
+
+
+@dataclass
+class Base:
+ """The Base class is the base class for all classes in the epochlib package.
+
+ Methods:
+ .. code-block:: python
+ def get_hash(self) -> str:
+ # Get the hash of base
+
+ def get_parent(self) -> Any:
+ # Get the parent of base.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of base
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+ def __post_init__(self) -> None:
+ """Initialize the block."""
+ self.set_hash("")
+ self.set_parent(None)
+ self.set_children([])
+
+ def set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the block.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = hash(prev_hash + str(self))
+
+ def get_hash(self) -> str:
+ """Get the hash of the block.
+
+ :return: The hash of the block.
+ """
+ return self._hash
+
+ def get_parent(self) -> Any:
+ """Get the parent of the block.
+
+ :return: Parent of the block
+ """
+ return self._parent
+
+ def get_children(self) -> Sequence[Any]:
+ """Get the children of the block.
+
+ :return: Children of the block
+ """
+ return self._children
+
+ def save_to_html(self, file_path: Path) -> None:
+ """Write html representation of class to file.
+
+ :param file_path: File path to write to
+ """
+ html = self._repr_html_()
+ with open(file_path, "w") as file:
+ file.write(html)
+
+ def set_parent(self, parent: Any) -> None:
+ """Set the parent of the block.
+
+ :param parent: Parent of the block
+ """
+ self._parent = parent
+
+ def set_children(self, children: Sequence[Any]) -> None:
+ """Set the children of the block.
+
+ :param children: Children of the block
+ """
+ self._children = children
+
+ def _repr_html_(self) -> str:
+ """Return representation of class in html format.
+
+ :return: String representation of html
+ """
+ html = ""
+ html += f"
Class: {self.__class__.__name__}
"
+ html += "
"
+ html += f"- Hash: {self.get_hash()}
"
+ html += f"- Parent: {self.get_parent()}
"
+ html += "- Children: "
+ if self.get_children():
+ html += "
"
+ for child in self.get_children():
+ html += f"- {child._repr_html_()}
"
+ html += "
"
+ else:
+ html += "None"
+ html += " "
+ html += "
"
+ html += "
"
+ return html
+
+
+class Block(Base):
+ """The Block class is the base class for all blocks.
+
+ Methods:
+ .. code-block:: python
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+
+@dataclass
+class ParallelSystem(Base):
+ """The System class is the base class for all systems.
+
+ Parameters:
+ - steps (list[_Base]): The steps in the system.
+ - weights (list[float]): Weights of steps in the system, if not specified they are equal.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def concat(self, original_data: Any), data_to_concat: Any, weight: float = 1.0) -> Any:
+ # Specifies how to concat data after parallel computations
+
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+ steps: list[Base] = field(default_factory=list)
+ weights: list[float] = field(default_factory=list)
+
+ def __post_init__(self) -> None:
+ """Post init function of _System class."""
+ # Sort the steps by name, to ensure consistent ordering of parallel computations
+ self.steps = sorted(self.steps, key=lambda x: x.__class__.__name__)
+
+ super().__post_init__()
+
+ # Set parent and children
+ for step in self.steps:
+ step.set_parent(self)
+
+ # Set weights if they exist
+ if len(self.weights) == len(self.get_steps()):
+ [w / sum(self.weights) for w in self.weights]
+ else:
+ num_steps = len(self.get_steps())
+ self.weights = [1 / num_steps for x in self.steps]
+
+ self.set_children(self.steps)
+
+ def get_steps(self) -> list[Base]:
+ """Return list of steps of ParallelSystem.
+
+ :return: List of steps
+ """
+ return self.steps
+
+ def get_weights(self) -> list[float]:
+ """Return list of weights of ParallelSystem.
+
+ :return: List of weights
+ """
+ if len(self.get_steps()) != len(self.weights):
+ raise TypeError("Mismatch between weights and steps")
+ return self.weights
+
+ def set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the system.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = prev_hash
+
+ # System has no steps and as such hash should not be affected
+ if len(self.steps) == 0:
+ return
+
+ # System is one step and should act as such
+ if len(self.steps) == 1:
+ step = self.steps[0]
+ step.set_hash(prev_hash)
+ self._hash = step.get_hash()
+ return
+
+ # System has at least two steps so hash should become a combination
+ total = self.get_hash()
+ for step in self.steps:
+ step.set_hash(prev_hash)
+ total = total + step.get_hash()
+
+ self._hash = hash(total)
+
+ @abstractmethod
+ def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
+ """Concatenate the transformed data.
+
+ :param original_data: The first input data.
+ :param data_to_concat: The second input data.
+ :param weight: Weight of data to concat
+ :return: The concatenated data.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement concat method.")
+
+
+@dataclass
+class SequentialSystem(Base):
+ """The SequentialSystem class is the base class for all systems.
+
+ Parameters:
+ - steps (list[_Base]): The steps in the system.
+
+ Methods:
+ .. code-block:: python
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+ """
+
+ steps: Sequence[Base] = field(default_factory=list)
+
+ def __post_init__(self) -> None:
+ """Post init function of _System class."""
+ super().__post_init__()
+
+ # Set parent and children
+ for step in self.steps:
+ step.set_parent(self)
+
+ self.set_children(self.steps)
+
+ def get_steps(self) -> Sequence[Base]:
+ """Return list of steps of _ParallelSystem.
+
+ :return: List of steps
+ """
+ return self.steps
+
+ def set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the system.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = prev_hash
+
+ # Set hash of each step using previous hash and then update hash with last step
+ for step in self.steps:
+ step.set_hash(self.get_hash())
+ self._hash = step.get_hash()
diff --git a/epochlib/pipeline/training.py b/epochlib/pipeline/training.py
new file mode 100644
index 0000000..dabada3
--- /dev/null
+++ b/epochlib/pipeline/training.py
@@ -0,0 +1,436 @@
+"""This module contains classes for training and predicting on data."""
+
+import copy
+import warnings
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from joblib import hash
+
+from .core import Base, Block, ParallelSystem, SequentialSystem
+from .transforming import TransformingSystem
+
+
+class TrainType(Base):
+ """Abstract train type describing a class that implements two functions train and predict."""
+
+ @abstractmethod
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the block.
+
+ :param x: The input data.
+ :param y: The target variable.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement train method.")
+
+ @abstractmethod
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the target variable.
+
+ :param x: The input data.
+ :return: The predictions.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement predict method.")
+
+
+class Trainer(TrainType, Block):
+ """The trainer block is for blocks that need to train on two inputs and predict on one.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ # Train the block.
+
+ @abstractmethod
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ # Predict the target variable.
+
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import Trainer
+
+
+ class MyTrainer(Trainer):
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ # Train the block.
+ return x, y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ # Predict the target variable.
+ return x
+
+
+ my_trainer = MyTrainer()
+ predictions, labels = my_trainer.train(x, y)
+ predictions = my_trainer.predict(x)
+ """
+
+
+class TrainingSystem(TrainType, SequentialSystem):
+ """A system that trains on the input data and labels.
+
+ Parameters:
+ - steps (list[TrainType]): The steps in the system.
+
+ Methods:
+ .. code-block:: python
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # Train the system.
+
+ def predict(self, x: Any, **pred_args: Any) -> Any: # Predict the output of the system.
+
+ def get_hash(self) -> str:
+ # Get the hash of the block.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the block.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the block
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import TrainingSystem
+
+ trainer_1 = CustomTrainer()
+ trainer_2 = CustomTrainer()
+
+ training_system = TrainingSystem(steps=[trainer_1, trainer_2])
+ trained_x, trained_y = training_system.train(x, y)
+ predictions = training_system.predict(x)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the TrainingSystem class."""
+ # Assert all steps are a subclass of Trainer
+ for step in self.steps:
+ if not isinstance(
+ step,
+ (TrainType),
+ ):
+ raise TypeError(f"step: {step} is not an instance of TrainType")
+
+ super().__post_init__()
+
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the system.
+
+ :param x: The input to the system.
+ :param y: The output of the system.
+ :return: The input and output of the system.
+ """
+ set_of_steps = set()
+ for step in self.steps:
+ step_name = step.__class__.__name__
+ set_of_steps.add(step_name)
+
+ if set_of_steps != set(train_args.keys()):
+ # Raise a warning and print all the keys that do not match
+ warnings.warn(f"The following steps do not exist but were given in the kwargs: {set(train_args.keys()) - set_of_steps}", UserWarning, stacklevel=2)
+
+ # Loop through each step and call the train method
+ for step in self.steps:
+ step_name = step.__class__.__name__
+
+ step_args = train_args.get(step_name, {})
+ if isinstance(step, (TrainType)):
+ x, y = step.train(x, y, **step_args)
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return x, y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the output of the system.
+
+ :param x: The input to the system.
+ :return: The output of the system.
+ """
+ set_of_steps = set()
+ for step in self.steps:
+ step_name = step.__class__.__name__
+ set_of_steps.add(step_name)
+
+ if set_of_steps != set(pred_args.keys()):
+ # Raise a warning and print all the keys that do not match
+ warnings.warn(f"The following steps do not exist but were given in the kwargs: {set(pred_args.keys()) - set_of_steps}", UserWarning, stacklevel=2)
+
+ # Loop through each step and call the predict method
+ for step in self.steps:
+ step_name = step.__class__.__name__
+
+ step_args = pred_args.get(step_name, {})
+
+ if isinstance(step, (TrainType)):
+ x = step.predict(x, **step_args)
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return x
+
+
+class ParallelTrainingSystem(TrainType, ParallelSystem):
+ """A system that trains the input data in parallel.
+
+ Parameters:
+ - steps (list[Trainer | TrainingSystem | ParallelTrainingSystem]): The steps in the system.
+ - weights (list[float]): The weights of steps in the system, if not specified they are all equal.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def concat(self, data1: Any, data2: Any) -> Any: # Concatenate the transformed data.
+
+ def train(self, x: Any, y: Any) -> tuple[Any, Any]: # Train the system.
+
+ def predict(self, x: Any, pred_args: dict[str, Any] = {}) -> Any: # Predict the output of the system.
+
+ def concat_labels(self, data1: Any, data2: Any) -> Any: # Concatenate the transformed labels.
+
+ def get_hash(self) -> str: # Get the hash of the system.
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import ParallelTrainingSystem
+
+ trainer_1 = CustomTrainer()
+ trainer_2 = CustomTrainer()
+
+
+ class CustomParallelTrainingSystem(ParallelTrainingSystem):
+ def concat(self, data1: Any, data2: Any) -> Any:
+ # Concatenate the transformed data.
+ return data1 + data2
+
+
+ training_system = CustomParallelTrainingSystem(steps=[trainer_1, trainer_2])
+ trained_x, trained_y = training_system.train(x, y)
+ predictions = training_system.predict(x)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the ParallelTrainingSystem class."""
+ # Assert all steps correct instances
+ for step in self.steps:
+ if not isinstance(step, (TrainType)):
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ super().__post_init__()
+
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the system.
+
+ :param x: The input to the system.
+ :param y: The expected output of the system.
+ :return: The input and output of the system.
+ """
+ # Loop through each step and call the train method
+ out_x, out_y = None, None
+ for i, step in enumerate(self.steps):
+ step_name = step.__class__.__name__
+
+ step_args = train_args.get(step_name, {})
+
+ if isinstance(step, (TrainType)):
+ step_x, step_y = step.train(copy.deepcopy(x), copy.deepcopy(y), **step_args)
+ out_x, out_y = (
+ self.concat(out_x, step_x, self.get_weights()[i]),
+ self.concat_labels(out_y, step_y, self.get_weights()[i]),
+ )
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return out_x, out_y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the output of the system.
+
+ :param x: The input to the system.
+ :return: The output of the system.
+ """
+ # Loop through each trainer and call the predict method
+ out_x = None
+ for i, step in enumerate(self.steps):
+ step_name = step.__class__.__name__
+
+ step_args = pred_args.get(step_name, {})
+
+ if isinstance(step, (TrainType)):
+ step_x = step.predict(copy.deepcopy(x), **step_args)
+ out_x = self.concat(out_x, step_x, self.get_weights()[i])
+ else:
+ raise TypeError(f"{step} is not an instance of TrainType")
+
+ return out_x
+
+ def concat_labels(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any:
+ """Concatenate the transformed labels. Will use concat method if not overridden.
+
+ :param original_data: The first input data.
+ :param data_to_concat: The second input data.
+ :param weight: Weight of data to concat
+ :return: The concatenated data.
+ """
+ return self.concat(original_data, data_to_concat, weight)
+
+
+@dataclass
+class Pipeline(TrainType):
+ """A pipeline of systems that can be trained and predicted.
+
+ Parameters:
+ - x_sys (TransformingSystem | None): The system to transform the input data.
+ - y_sys (TransformingSystem | None): The system to transform the labelled data.
+ - train_sys (TrainingSystem | None): The system to train the data.
+ - pred_sys (TransformingSystem | None): The system to transform the predictions.
+ - label_sys (TransformingSystem | None): The system to transform the labels.
+
+ Methods:
+ .. code-block:: python
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ # Train the system.
+
+ def predict(self, x: Any, **pred_args) -> Any:
+ # Predict the output of the system.
+
+ def get_hash(self) -> str:
+ # Get the hash of the pipeline
+
+ def get_parent(self) -> Any:
+ # Get the parent of the pipeline
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the pipeline
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import Pipeline
+
+ x_sys = CustomTransformingSystem()
+ y_sys = CustomTransformingSystem()
+ train_sys = CustomTrainingSystem()
+ pred_sys = CustomTransformingSystem()
+ label_sys = CustomTransformingSystem()
+
+ pipeline = Pipeline(x_sys=x_sys, y_sys=y_sys, train_sys=train_sys, pred_sys=pred_sys, label_sys=label_sys)
+ trained_x, trained_y = pipeline.train(x, y)
+ predictions = pipeline.predict(x)
+ """
+
+ x_sys: TransformingSystem | None = None
+ y_sys: TransformingSystem | None = None
+ train_sys: Trainer | TrainingSystem | ParallelTrainingSystem | None = None
+ pred_sys: TransformingSystem | None = None
+ label_sys: TransformingSystem | None = None
+
+ def __post_init__(self) -> None:
+ """Post initialization function of the Pipeline."""
+ super().__post_init__()
+
+ # Set children and parents
+ children = []
+ systems = [
+ self.x_sys,
+ self.y_sys,
+ self.train_sys,
+ self.pred_sys,
+ self.label_sys,
+ ]
+
+ for sys in systems:
+ if sys is not None:
+ sys.set_parent(self)
+ children.append(sys)
+
+ self.set_children(children)
+
+ def train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
+ """Train the system.
+
+ :param x: The input to the system.
+ :param y: The expected output of the system.
+ :param train_args: The arguments to pass to the training system. (Default is {})
+ :return: The input and output of the system.
+ """
+ if self.x_sys is not None:
+ x = self.x_sys.transform(x, **train_args.get("x_sys", {}))
+ if self.y_sys is not None:
+ y = self.y_sys.transform(y, **train_args.get("y_sys", {}))
+ if self.train_sys is not None:
+ x, y = self.train_sys.train(x, y, **train_args.get("train_sys", {}))
+ if self.pred_sys is not None:
+ x = self.pred_sys.transform(x, **train_args.get("pred_sys", {}))
+ if self.label_sys is not None:
+ y = self.label_sys.transform(y, **train_args.get("label_sys", {}))
+
+ return x, y
+
+ def predict(self, x: Any, **pred_args: Any) -> Any:
+ """Predict the output of the system.
+
+ :param x: The input to the system.
+ :param pred_args: The arguments to pass to the prediction system. (Default is {})
+ :return: The output of the system.
+ """
+ if self.x_sys is not None:
+ x = self.x_sys.transform(x, **pred_args.get("x_sys", {}))
+ if self.train_sys is not None:
+ x = self.train_sys.predict(x, **pred_args.get("train_sys", {}))
+ if self.pred_sys is not None:
+ x = self.pred_sys.transform(x, **pred_args.get("pred_sys", {}))
+
+ return x
+
+ def _set_hash(self, prev_hash: str) -> None:
+ """Set the hash of the pipeline.
+
+ :param prev_hash: The hash of the previous block.
+ """
+ self._hash = prev_hash
+
+ xy_hash = ""
+ if self.x_sys is not None:
+ self.x_sys.set_hash(self.get_hash())
+ xy_hash += self.x_sys.get_hash()
+ if self.y_sys is not None:
+ self.y_sys.set_hash(self.get_hash())
+ xy_hash += self.y_sys.get_hash()[::-1] # Reversed for edge case where you have two pipelines with the same system but one in x the other in y
+
+ if xy_hash != "":
+ self._hash = hash(xy_hash)
+
+ if self.train_sys is not None:
+ self.train_sys.set_hash(self.get_hash())
+ training_hash = self.train_sys.get_hash()
+ if training_hash != "":
+ self._hash = hash(self._hash + training_hash)
+
+ predlabel_hash = ""
+ if self.pred_sys is not None:
+ self.pred_sys.set_hash(self.get_hash())
+ predlabel_hash += self.pred_sys.get_hash()
+ if self.label_sys is not None:
+ self.label_sys.set_hash(self.get_hash())
+ predlabel_hash += self.label_sys.get_hash()
+
+ if predlabel_hash != "":
+ self._hash = hash(predlabel_hash)
diff --git a/epochlib/pipeline/transforming.py b/epochlib/pipeline/transforming.py
new file mode 100644
index 0000000..529aa26
--- /dev/null
+++ b/epochlib/pipeline/transforming.py
@@ -0,0 +1,209 @@
+"""This module contains the classes for transforming data in the epochlib package."""
+
+import copy
+import warnings
+from abc import abstractmethod
+from typing import Any
+
+from .core import Base, Block, ParallelSystem, SequentialSystem
+
+
+class TransformType(Base):
+ """Abstract transform type describing a class that implements the transform function."""
+
+ @abstractmethod
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ """Transform the input data.
+
+ :param data: The input data.
+ :param transform_args: Keyword arguments.
+ :return: The transformed data.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} does not implement transform method.")
+
+
+class Transformer(TransformType, Block):
+ """The transformer block transforms any data it could be x or y data.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ # Transform the input data.
+
+ def get_hash(self) -> str:
+ # Get the hash of the Transformer
+
+ def get_parent(self) -> Any:
+ # Get the parent of the Transformer
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the Transformer
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import Transformer
+
+
+ class MyTransformer(Transformer):
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ # Transform the input data.
+ return data
+
+
+ my_transformer = MyTransformer()
+ transformed_data = my_transformer.transform(data)
+ """
+
+
+class TransformingSystem(TransformType, SequentialSystem):
+ """A system that transforms the input data.
+
+ Parameters:
+ - steps (list[Transformer | TransformingSystem | ParallelTransformingSystem]): The steps in the system.
+
+ Implements the following methods:
+ .. code-block:: python
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ # Transform the input data.
+
+ def get_hash(self) -> str:
+ # Get the hash of the TransformingSystem
+
+ def get_parent(self) -> Any:
+ # Get the parent of the TransformingSystem
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the TransformingSystem
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import TransformingSystem
+
+ transformer_1 = CustomTransformer()
+ transformer_2 = CustomTransformer()
+
+ transforming_system = TransformingSystem(steps=[transformer_1, transformer_2])
+ transformed_data = transforming_system.transform(data)
+ predictions = transforming_system.predict(data)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the TransformingSystem class."""
+ # Assert all steps are a subclass of Transformer
+ for step in self.steps:
+ if not isinstance(step, (TransformType)):
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ super().__post_init__()
+
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ """Transform the input data.
+
+ :param data: The input data.
+ :return: The transformed data.
+ """
+ set_of_steps = set()
+ for step in self.steps:
+ step_name = step.__class__.__name__
+ set_of_steps.add(step_name)
+ if set_of_steps != set(transform_args.keys()):
+ # Raise a warning and print all the keys that do not match
+ warnings.warn(f"The following steps do not exist but were given in the kwargs: {set(transform_args.keys()) - set_of_steps}", stacklevel=2)
+
+ # Loop through each step and call the transform method
+ for step in self.steps:
+ step_name = step.__class__.__name__
+
+ step_args = transform_args.get(step_name, {})
+ if isinstance(step, (TransformType)):
+ data = step.transform(data, **step_args)
+ else:
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ return data
+
+
+class ParallelTransformingSystem(TransformType, ParallelSystem):
+ """A system that transforms the input data in parallel.
+
+ Parameters:
+ - steps (list[Transformer | TransformingSystem | ParallelTransformingSystem]): The steps in the system.
+ - weights (list[float]): Weights of steps in system, if not specified they are all equal.
+
+ Methods:
+ .. code-block:: python
+ @abstractmethod
+ def concat(self, original_data: Any), data_to_concat: Any, weight: float = 1.0) -> Any:
+ # Specifies how to concat data after parallel computations
+
+ def get_hash(self) -> str:
+ # Get the hash of the ParallelTransformingSystem.
+
+ def get_parent(self) -> Any:
+ # Get the parent of the ParallelTransformingSystem.
+
+ def get_children(self) -> list[Any]:
+ # Get the children of the ParallelTransformingSystem
+
+ def save_to_html(self, file_path: Path) -> None:
+ # Save html format to file_path
+
+ Usage:
+ .. code-block:: python
+ from epochlib.pipeline import ParallelTransformingSystem
+
+ transformer_1 = CustomTransformer()
+ transformer_2 = CustomTransformer()
+
+
+ class CustomParallelTransformingSystem(ParallelTransformingSystem):
+ def concat(self, data1: Any, data2: Any) -> Any:
+ # Concatenate the transformed data.
+ return data1 + data2
+
+
+ transforming_system = CustomParallelTransformingSystem(steps=[transformer_1, transformer_2])
+
+ transformed_data = transforming_system.transform(data)
+ """
+
+ def __post_init__(self) -> None:
+ """Post init method for the ParallelTransformingSystem class."""
+ # Assert all steps are a subclass of Transformer or TransformingSystem
+ for step in self.steps:
+ if not isinstance(step, (TransformType)):
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ super().__post_init__()
+
+ def transform(self, data: Any, **transform_args: Any) -> Any:
+ """Transform the input data.
+
+ :param data: The input data.
+ :return: The transformed data.
+ """
+ # Loop through each step and call the transform method
+ out_data = None
+ if len(self.get_steps()) == 0:
+ return data
+
+ for i, step in enumerate(self.get_steps()):
+ step_name = step.__class__.__name__
+
+ step_args = transform_args.get(step_name, {})
+
+ if isinstance(step, (TransformType)):
+ step_data = step.transform(copy.deepcopy(data), **step_args)
+ out_data = self.concat(out_data, step_data, self.get_weights()[i])
+ else:
+ raise TypeError(f"{step} is not an instance of TransformType")
+
+ return out_data
diff --git a/epochalyst/training/__init__.py b/epochlib/training/__init__.py
similarity index 80%
rename from epochalyst/training/__init__.py
rename to epochlib/training/__init__.py
index 46e26fa..a51b642 100644
--- a/epochalyst/training/__init__.py
+++ b/epochlib/training/__init__.py
@@ -1,4 +1,4 @@
-"""Module containing training functionality for the epochalyst package."""
+"""Module containing training functionality for the epochlib package."""
from .pretrain_block import PretrainBlock
from .torch_trainer import TorchTrainer, TrainValidationDataset
diff --git a/epochalyst/training/_custom_data_parallel.py b/epochlib/training/_custom_data_parallel.py
similarity index 100%
rename from epochalyst/training/_custom_data_parallel.py
rename to epochlib/training/_custom_data_parallel.py
diff --git a/epochalyst/training/augmentation/__init__.py b/epochlib/training/augmentation/__init__.py
similarity index 74%
rename from epochalyst/training/augmentation/__init__.py
rename to epochlib/training/augmentation/__init__.py
index 700a6a0..92d6e51 100644
--- a/epochalyst/training/augmentation/__init__.py
+++ b/epochlib/training/augmentation/__init__.py
@@ -1,7 +1,7 @@
"""Module containing implementation for augmentations."""
-from epochalyst.training.augmentation.image_augmentations import CutMix, MixUp
-from epochalyst.training.augmentation.time_series_augmentations import (
+from epochlib.training.augmentation.image_augmentations import CutMix, MixUp
+from epochlib.training.augmentation.time_series_augmentations import (
AddBackgroundNoiseWrapper,
CutMix1D,
EnergyCutmix,
diff --git a/epochalyst/training/augmentation/image_augmentations.py b/epochlib/training/augmentation/image_augmentations.py
similarity index 98%
rename from epochalyst/training/augmentation/image_augmentations.py
rename to epochlib/training/augmentation/image_augmentations.py
index bf042f7..86c27a1 100644
--- a/epochalyst/training/augmentation/image_augmentations.py
+++ b/epochlib/training/augmentation/image_augmentations.py
@@ -6,7 +6,7 @@
import torch
-def get_kornia_mix() -> Any: # noqa: ANN401
+def get_kornia_mix() -> Any:
"""Return kornia mix."""
try:
import kornia
diff --git a/epochalyst/training/augmentation/time_series_augmentations.py b/epochlib/training/augmentation/time_series_augmentations.py
similarity index 100%
rename from epochalyst/training/augmentation/time_series_augmentations.py
rename to epochlib/training/augmentation/time_series_augmentations.py
diff --git a/epochalyst/training/augmentation/utils.py b/epochlib/training/augmentation/utils.py
similarity index 98%
rename from epochalyst/training/augmentation/utils.py
rename to epochlib/training/augmentation/utils.py
index 2dcd370..908e6ec 100644
--- a/epochalyst/training/augmentation/utils.py
+++ b/epochlib/training/augmentation/utils.py
@@ -12,7 +12,7 @@
import torch
-from epochalyst.training.utils.recursive_repr import recursive_repr
+from epochlib.training.utils.recursive_repr import recursive_repr
def get_audiomentations() -> ModuleType:
diff --git a/epochalyst/training/models/__init__.py b/epochlib/training/models/__init__.py
similarity index 61%
rename from epochalyst/training/models/__init__.py
rename to epochlib/training/models/__init__.py
index 165fb70..655d4ab 100644
--- a/epochalyst/training/models/__init__.py
+++ b/epochlib/training/models/__init__.py
@@ -1,7 +1,9 @@
"""Module for reusable models or wrappers."""
+from .conv1d_bn_relu import Conv1dBnRelu
from .timm import Timm
__all__ = [
"Timm",
+ "Conv1dBnRelu",
]
diff --git a/epochlib/training/models/conv1d_bn_relu.py b/epochlib/training/models/conv1d_bn_relu.py
new file mode 100644
index 0000000..e1e5278
--- /dev/null
+++ b/epochlib/training/models/conv1d_bn_relu.py
@@ -0,0 +1,35 @@
+"""Conv1dBnRelu block for 1d cnn layer with batch normalization and relu."""
+
+from torch import Tensor, nn
+
+
+class Conv1dBnRelu(nn.Module):
+ """Conv1dBnRelu model."""
+
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, *, is_bn: bool = True) -> None:
+ """Initialize Conv1dBnRelu block.
+
+ :param in_channels: Number of in channels
+ :param out_channels: Number of out channels
+ :param kernel_size: Number of kernels
+ :param stride: Stride length
+ :param padding: Padding size
+ :param is_bn: Whether to use batch norm
+ """
+ super().__init__()
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
+ self.is_bn = is_bn
+ if self.is_bn:
+ self.bn1 = nn.BatchNorm1d(out_channels, eps=5e-3, momentum=0.1)
+ self.relu = nn.ReLU()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward tensor function.
+
+ :param x: Input tensor
+ :return: Output tensor
+ """
+ x = self.conv1(x)
+ if self.is_bn:
+ x = self.bn1(x)
+ return self.relu(x)
diff --git a/epochalyst/training/models/timm.py b/epochlib/training/models/timm.py
similarity index 100%
rename from epochalyst/training/models/timm.py
rename to epochlib/training/models/timm.py
diff --git a/epochalyst/training/pretrain_block.py b/epochlib/training/pretrain_block.py
similarity index 95%
rename from epochalyst/training/pretrain_block.py
rename to epochlib/training/pretrain_block.py
index 7aaef74..3b69a06 100644
--- a/epochalyst/training/pretrain_block.py
+++ b/epochlib/training/pretrain_block.py
@@ -34,7 +34,7 @@ def train_split_hash(self, train_indices: list[int]) -> str: # Split the hash on
Usage:
.. code-block:: python
- from epochalyst.pipeline.model.training.pretrain_block import PretrainBlock
+ from epochlib.pipeline.model.training.pretrain_block import PretrainBlock
class CustomPretrainBlock(PretrainBlock):
@@ -54,7 +54,7 @@ def custom_predict(self, x: Any, **pred_args: Any) -> Any:
test_size: float = 0.2
@abstractmethod
- def pretrain_train(self, x: Any, y: Any, train_indices: list[int], *, save_pretrain: bool = True, save_pretrain_with_split: bool = False) -> tuple[Any, Any]: # noqa: ANN401
+ def pretrain_train(self, x: Any, y: Any, train_indices: list[int], *, save_pretrain: bool = True, save_pretrain_with_split: bool = False) -> tuple[Any, Any]:
"""Train pretrain block method.
:param x: The input to the system.
@@ -67,7 +67,7 @@ def pretrain_train(self, x: Any, y: Any, train_indices: list[int], *, save_pretr
f"Train method not implemented for {self.__class__.__name__}",
)
- def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Call the pretrain train method.
:param x: The input to the system.
diff --git a/epochalyst/training/torch_trainer.py b/epochlib/training/torch_trainer.py
similarity index 93%
rename from epochalyst/training/torch_trainer.py
rename to epochlib/training/torch_trainer.py
index 7f49c05..a2ba369 100644
--- a/epochalyst/training/torch_trainer.py
+++ b/epochlib/training/torch_trainer.py
@@ -1,5 +1,6 @@
"""TorchTrainer is a module that allows for the training of Torch models."""
+import contextlib
import copy
import functools
import gc
@@ -7,7 +8,7 @@
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
-from typing import Annotated, Any, Literal, TypeVar
+from typing import Annotated, Any, Literal, Tuple, TypeVar
import numpy as np
import numpy.typing as npt
@@ -19,6 +20,8 @@
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
+from epochlib.data import Data
+
from ._custom_data_parallel import _CustomDataParallel
from .training_block import TrainingBlock
from .utils import _get_onnxrt, _get_openvino, batch_to_device
@@ -61,6 +64,10 @@ class TorchTrainer(TrainingBlock):
- `checkpointing_keep_every` (int): Keep every i'th checkpoint (1 to keep all, 0 to keep only the last one)
- `checkpointing_resume_if_exists` (bool): Resume training if a checkpoint exists
+ Parameters Precision
+ ----------
+ - `use_mixed_precision` (bool): Whether to use mixed precision for the model training
+
Parameters Misc
----------
- `to_predict` (str): Whether to predict on the 'validation' set, 'all' data or 'none'
@@ -127,7 +134,7 @@ def update_model_directory(model_directory: str) -> None:
Usage:
.. code-block:: python
- from epochalyst.pipeline.model.training.torch_trainer import TorchTrainer
+ from epochlib.pipeline.model.training.torch_trainer import TorchTrainer
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
@@ -161,6 +168,7 @@ def log_to_terminal(self, message: str) -> None:
criterion: nn.Module
scheduler: Callable[[Optimizer], LRScheduler] | None = None
dataloader_args: dict[str, Any] = field(default_factory=dict, repr=False)
+ dataset: functools.partial[Dataset[Tuple[Tensor, Tensor]]] | None = None
# Training parameters
epochs: Annotated[int, Gt(0)] = 10
@@ -173,6 +181,9 @@ def log_to_terminal(self, message: str) -> None:
checkpointing_keep_every: Annotated[int, Gt(0)] = field(default=0, init=True, repr=False, compare=False)
checkpointing_resume_if_exists: bool = field(default=True, init=True, repr=False, compare=False)
+ # Precision
+ use_mixed_precision: bool = field(default=False)
+
# Misc
model_name: str | None = None # No spaces allowed
trained_models_directory: PathLike[str] = field(default=Path("tm/"), repr=False, compare=False)
@@ -233,6 +244,12 @@ def __post_init__(self) -> None:
self.last_val_loss = np.inf
self.lowest_val_loss = np.inf
+ # Mixed precision
+ if self.use_mixed_precision:
+ self.log_to_terminal("Using mixed precision training.")
+ self.scaler = torch.GradScaler(device=self.device.type)
+ torch.set_float32_matmul_precision("high")
+
# Check validity of model_name
if " " in self.model_name:
raise ValueError("Spaces in model_name not allowed")
@@ -281,7 +298,7 @@ def custom_train(self, x: npt.NDArray[np.float32], y: npt.NDArray[np.float32], *
self._load_model()
# Return the predictions
- return self._predict_after_train(
+ return self.predict_after_train(
x,
y,
train_dataset,
@@ -337,7 +354,7 @@ def custom_train(self, x: npt.NDArray[np.float32], y: npt.NDArray[np.float32], *
if save_model:
self._save_model()
- return self._predict_after_train(
+ return self.predict_after_train(
x,
y,
train_dataset,
@@ -346,7 +363,7 @@ def custom_train(self, x: npt.NDArray[np.float32], y: npt.NDArray[np.float32], *
validation_indices,
)
- def _predict_after_train(
+ def predict_after_train(
self,
x: npt.NDArray[np.float32],
y: npt.NDArray[np.float32],
@@ -392,7 +409,7 @@ def _predict_after_train(
case _:
raise ValueError("to_predict should be either 'validation', 'all' or 'none")
- def custom_predict(self, x: Any, **pred_args: Any) -> npt.NDArray[np.float32]: # noqa: ANN401
+ def custom_predict(self, x: Any, **pred_args: Any) -> npt.NDArray[np.float32]:
"""Predict on the validation data.
:param x: The input to the system.
@@ -448,7 +465,7 @@ def predict_on_loader(
:param loader: The loader to predict on.
:return: The predictions.
"""
- self.log_to_terminal("Predicting on the validation data")
+ self.log_to_terminal("Running inference on the given dataloader")
self.model.eval()
predictions = []
# Create a new dataloader from the dataset of the input dataloader with collate_fn
@@ -517,8 +534,8 @@ def get_hash(self) -> str:
def create_datasets(
self,
- x: npt.NDArray[np.float32],
- y: npt.NDArray[np.float32],
+ x: npt.NDArray[np.float32] | Data,
+ y: npt.NDArray[np.float32] | Data,
train_indices: list[int],
validation_indices: list[int],
) -> tuple[Dataset[tuple[Tensor, ...]], Dataset[tuple[Tensor, ...]]]:
@@ -530,27 +547,31 @@ def create_datasets(
:param validation_indices: The indices to validate on.
:return: The training and validation datasets.
"""
- train_dataset = TensorDataset(
- torch.from_numpy(x[train_indices]),
- torch.from_numpy(y[train_indices]),
- )
- validation_dataset = TensorDataset(
- torch.from_numpy(x[validation_indices]),
- torch.from_numpy(y[validation_indices]),
- )
+ if self.dataset is None:
+ train_dataset = TensorDataset(
+ torch.from_numpy(x[train_indices]),
+ torch.from_numpy(y[train_indices]),
+ )
+ validation_dataset = TensorDataset(
+ torch.from_numpy(x[validation_indices]),
+ torch.from_numpy(y[validation_indices]),
+ )
+ return train_dataset, validation_dataset
- return train_dataset, validation_dataset
+ return self.dataset(x[train_indices], y[train_indices]), self.dataset(x[validation_indices], y[validation_indices])
def create_prediction_dataset(
self,
- x: npt.NDArray[np.float32],
+ x: npt.NDArray[np.float32] | Data,
) -> Dataset[tuple[Tensor, ...]]:
"""Create the prediction dataset.
:param x: The input data.
:return: The prediction dataset.
"""
- return TensorDataset(torch.from_numpy(x))
+ if self.dataset is None:
+ return TensorDataset(torch.from_numpy(x))
+ return self.dataset(x)
def create_dataloaders(
self,
@@ -615,7 +636,7 @@ def _training_loop(
for epoch in range(start_epoch, self.epochs):
# Train using train_loader
- train_loss = self._train_one_epoch(train_loader, epoch)
+ train_loss = self.train_one_epoch(train_loader, epoch)
self.log_to_debug(f"Epoch {epoch} Train Loss: {train_loss}")
train_losses.append(train_loss)
@@ -642,7 +663,7 @@ def _training_loop(
# Compute validation loss
if len(validation_loader) > 0:
- self.last_val_loss = self._val_one_epoch(
+ self.last_val_loss = self.val_one_epoch(
validation_loader,
desc=f"Epoch {epoch} Valid",
)
@@ -681,7 +702,7 @@ def _training_loop(
# Log the trained epochs to wandb if we finished training
self.log_to_external(message={self.wrap_log(f"Epochs{fold_no}"): epoch + 1})
- def _train_one_epoch(
+ def train_one_epoch(
self,
dataloader: DataLoader[tuple[Tensor, ...]],
epoch: int,
@@ -706,13 +727,19 @@ def _train_one_epoch(
y_batch = batch_to_device(y_batch, self.y_tensor_type, self.device)
# Forward pass
- y_pred = self.model(X_batch).squeeze(1)
- loss = self.criterion(y_pred, y_batch)
+ with torch.autocast(self.device.type) if self.use_mixed_precision else contextlib.nullcontext(): # type: ignore[attr-defined]
+ y_pred = self.model(X_batch).squeeze(1)
+ loss = self.criterion(y_pred, y_batch)
# Backward pass
self.initialized_optimizer.zero_grad()
- loss.backward()
- self.initialized_optimizer.step()
+ if self.use_mixed_precision:
+ self.scaler.scale(loss).backward()
+ self.scaler.step(self.initialized_optimizer)
+ self.scaler.update()
+ else:
+ loss.backward()
+ self.initialized_optimizer.step()
# Print tqdm
losses.append(loss.item())
@@ -724,7 +751,7 @@ def _train_one_epoch(
return sum(losses) / len(losses)
- def _val_one_epoch(
+ def val_one_epoch(
self,
dataloader: DataLoader[tuple[Tensor, ...]],
desc: str,
diff --git a/epochalyst/training/training.py b/epochlib/training/training.py
similarity index 92%
rename from epochalyst/training/training.py
rename to epochlib/training/training.py
index 5453719..0b35ba0 100644
--- a/epochalyst/training/training.py
+++ b/epochlib/training/training.py
@@ -1,19 +1,20 @@
"""TrainingPipeline for creating a sequential pipeline of TrainType classes."""
+from dataclasses import dataclass
from typing import Any
-from agogos.training import TrainingSystem, TrainType
-
-from epochalyst.caching import CacheArgs, Cacher
+from epochlib.caching import CacheArgs, Cacher
+from epochlib.pipeline import TrainingSystem, TrainType
+@dataclass
class TrainingPipeline(TrainingSystem, Cacher):
"""The training pipeline. This is the class used to create the pipeline for the training of the model.
:param steps: The steps to train the model.
"""
- def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]:
"""Train the system.
:param x: The input to the system.
@@ -36,7 +37,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
# Furthest step
for i, step in enumerate(self.get_steps()):
# Check if step is instance of Cacher and if cache_args exists
- if not isinstance(step, Cacher) or not isinstance(step, TrainType):
+ if not isinstance(step, TrainType) or not isinstance(step, Cacher):
self.log_to_debug(f"{step} is not instance of Cacher or TrainType")
continue
@@ -72,7 +73,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
return x, y
- def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any: # noqa: ANN401
+ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any:
"""Predict the output of the system.
:param x: The input to the system.
@@ -90,7 +91,7 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)
# Retrieve furthest step calculated
for i, step in enumerate(self.get_steps()):
# Check if step is instance of Cacher and if cache_args exists
- if not isinstance(step, Cacher) or not isinstance(step, TrainType):
+ if not isinstance(step, TrainType) or not isinstance(step, Cacher):
self.log_to_debug(f"{step} is not instance of Cacher or TrainType")
continue
diff --git a/epochalyst/training/training_block.py b/epochlib/training/training_block.py
similarity index 92%
rename from epochalyst/training/training_block.py
rename to epochlib/training/training_block.py
index d64cf8f..c37d556 100644
--- a/epochalyst/training/training_block.py
+++ b/epochlib/training/training_block.py
@@ -3,9 +3,8 @@
from abc import abstractmethod
from typing import Any
-from agogos.training import Trainer
-
-from epochalyst.caching import CacheArgs, Cacher
+from epochlib.caching import CacheArgs, Cacher
+from epochlib.pipeline import Trainer
class TrainingBlock(Trainer, Cacher):
@@ -41,7 +40,7 @@ def predict(self, x: Any, cache_args: dict[str, Any] = {}, **pred_args: Any) ->
Usage:
.. code-block:: python
- from epochalyst.pipeline.model.training.training_block import TrainingBlock
+ from epochlib.pipeline.model.training.training_block import TrainingBlock
class CustomTrainingBlock(TrainingBlock):
def custom_train(self, x: Any, y: Any) -> tuple[Any, Any]:
@@ -58,7 +57,7 @@ def custom_predict(self, x: Any) -> Any:
x = custom_training_block.predict(x)
"""
- def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_args: Any) -> tuple[Any, Any]:
"""Train the model.
:param x: The input data.
@@ -92,7 +91,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
return x, y
@abstractmethod
- def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: # noqa: ANN401
+ def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]:
"""Train the model.
:param x: The input data.
@@ -103,7 +102,7 @@ def custom_train(self, x: Any, y: Any, **train_args: Any) -> tuple[Any, Any]: #
f"Custom transform method not implemented for {self.__class__}",
)
- def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any: # noqa: ANN401
+ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any) -> Any:
"""Predict using the model.
:param x: The input data.
@@ -129,7 +128,7 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)
return x
@abstractmethod
- def custom_predict(self, x: Any, **pred_args: Any) -> Any: # noqa: ANN401
+ def custom_predict(self, x: Any, **pred_args: Any) -> Any:
"""Predict using the model.
:param x: The input data.
diff --git a/epochalyst/training/utils/__init__.py b/epochlib/training/utils/__init__.py
similarity index 100%
rename from epochalyst/training/utils/__init__.py
rename to epochlib/training/utils/__init__.py
diff --git a/epochalyst/training/utils/get_dependencies.py b/epochlib/training/utils/get_dependencies.py
similarity index 87%
rename from epochalyst/training/utils/get_dependencies.py
rename to epochlib/training/utils/get_dependencies.py
index b4f3239..2bd7273 100644
--- a/epochalyst/training/utils/get_dependencies.py
+++ b/epochlib/training/utils/get_dependencies.py
@@ -3,7 +3,7 @@
from typing import Any
-def _get_onnxrt() -> Any: # noqa: ANN401
+def _get_onnxrt() -> Any:
"""Return onnxruntime."""
try:
import onnxruntime as onnxrt
@@ -17,7 +17,7 @@ def _get_onnxrt() -> Any: # noqa: ANN401
return onnxrt
-def _get_openvino() -> Any: # noqa: ANN401
+def _get_openvino() -> Any:
"""Return openvino."""
try:
import openvino
diff --git a/epochalyst/training/utils/recursive_repr.py b/epochlib/training/utils/recursive_repr.py
similarity index 100%
rename from epochalyst/training/utils/recursive_repr.py
rename to epochlib/training/utils/recursive_repr.py
diff --git a/epochalyst/training/utils/tensor_functions.py b/epochlib/training/utils/tensor_functions.py
similarity index 100%
rename from epochalyst/training/utils/tensor_functions.py
rename to epochlib/training/utils/tensor_functions.py
diff --git a/epochalyst/transformation/__init__.py b/epochlib/transformation/__init__.py
similarity index 100%
rename from epochalyst/transformation/__init__.py
rename to epochlib/transformation/__init__.py
diff --git a/epochalyst/transformation/transformation.py b/epochlib/transformation/transformation.py
similarity index 92%
rename from epochalyst/transformation/transformation.py
rename to epochlib/transformation/transformation.py
index cbd342c..a8515e7 100644
--- a/epochalyst/transformation/transformation.py
+++ b/epochlib/transformation/transformation.py
@@ -3,9 +3,8 @@
from dataclasses import dataclass
from typing import Any
-from agogos.transforming import TransformingSystem, TransformType
-
-from epochalyst.caching.cacher import CacheArgs, Cacher
+from epochlib.caching.cacher import CacheArgs, Cacher
+from epochlib.pipeline import TransformingSystem, TransformType
@dataclass
@@ -43,7 +42,7 @@ def get_hash(self) -> str: # Get the hash of the pipeline.
Usage:
.. code-block:: python
- from epochalyst.pipeline.model.transformation import TransformationPipeline
+ from epochlib.pipeline.model.transformation import TransformationPipeline
class MyTransformationPipeline(TransformationPipeline):
def log_to_terminal(self, message: str) -> None:
@@ -60,7 +59,7 @@ def log_to_terminal(self, message: str) -> None:
title: str = "Transformation Pipeline" # The title of the pipeline since transformation pipeline can be used for multiple purposes. (Feature, Label, etc.)
- def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any: # noqa: ANN401
+ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any:
"""Transform the input data.
:param data: The input data.
@@ -81,7 +80,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
# Furthest step
for i, step in enumerate(self.get_steps()):
# Check if step is instance of Cacher and if cache_args exists
- if not isinstance(step, Cacher) or not isinstance(step, TransformType):
+ if not isinstance(step, TransformType) or not isinstance(step, Cacher):
self.log_to_debug(f"{step} is not instance of Cacher or TransformType")
continue
diff --git a/epochalyst/transformation/transformation_block.py b/epochlib/transformation/transformation_block.py
similarity index 91%
rename from epochalyst/transformation/transformation_block.py
rename to epochlib/transformation/transformation_block.py
index a163c7b..fbe82cf 100644
--- a/epochalyst/transformation/transformation_block.py
+++ b/epochlib/transformation/transformation_block.py
@@ -3,9 +3,8 @@
from abc import abstractmethod
from typing import Any
-from agogos.transforming import Transformer
-
-from epochalyst.caching.cacher import CacheArgs, Cacher
+from epochlib.caching.cacher import CacheArgs, Cacher
+from epochlib.pipeline import Transformer
class TransformationBlock(Transformer, Cacher):
@@ -36,7 +35,7 @@ def transform(self, data: Any, cache_args: dict[str, Any] = {}, **transform_args
Usage:
.. code-block:: python
- from epochalyst.pipeline.model.transformation.transformation_block import TransformationBlock
+ from epochlib.pipeline.model.transformation.transformation_block import TransformationBlock
class CustomTransformationBlock(TransformationBlock):
def custom_transform(self, data: Any) -> Any:
@@ -55,7 +54,7 @@ def custom_transform(self, data: Any) -> Any:
data = custom_transformation_block.transform(data, cache=cache_args)
"""
- def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any: # noqa: ANN401
+ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_args: Any) -> Any:
"""Transform the input data using a custom method.
:param data: The input data.
@@ -78,7 +77,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
return data
@abstractmethod
- def custom_transform(self, data: Any, **transform_args: Any) -> Any: # noqa: ANN401
+ def custom_transform(self, data: Any, **transform_args: Any) -> Any:
"""Transform the input data using a custom method.
:param data: The input data.
diff --git a/pyproject.toml b/pyproject.toml
index ca4ff9a..11e74bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
-name = "epochalyst"
-version = "4.0"
+name = "epochlib"
+version = "5.0.0"
authors = [
{ name = "Jasper van Selm", email = "jmvanselm@gmail.com" },
{ name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" },
@@ -12,7 +12,7 @@ authors = [
{ name = "Kristóf Sandor", email = "emherk512@gmail.com"},
{ name = "Daniel De Dios Allegue", email = "danieldediosallegue@gmail.com"}
]
-description = "Epochalyst is the base for Team Epoch competitions."
+description = "Epoch Libraries is the base for Team Epoch competitions."
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.10"
@@ -30,10 +30,10 @@ classifiers = [
]
dependencies = [
"torch>=2.1.0",
- "agogos>=0.4",
"joblib>=1.4.0",
"annotated-types>=0.6.0",
"typing-extensions>=4.9.0; python_version<'3.12'",
+ "numpy>= 1.22.4, < 3",
]
[project.optional-dependencies]
@@ -41,9 +41,6 @@ image = [
"kornia>=0.7.2",
"timm>=0.9.16",
]
-numpy = [
- "numpy>=1.22.4, < 2",
-]
pandas = [
"pandas[performance, parquet]>=2.0.0",
]
@@ -67,18 +64,18 @@ audio = [
[project.urls]
Homepage = "https://teamepoch.ai/"
-Documentation = "https://TeamEpochGithub.github.io/epochalyst/"
-Repository = "https://github.com/TeamEpochGithub/epochalyst"
-Download = "https://pypi.org/project/epochalyst/#files"
-Issues = "https://github.com/TeamEpochGithub/epochalyst/issues"
-"Release notes" = "https://github.com/TeamEpochGithub/epochalyst/releases"
+Documentation = "https://TeamEpochGithub.github.io/epochlib/"
+Repository = "https://github.com/TeamEpochGithub/epochlib"
+Download = "https://pypi.org/project/epochlib/#files"
+Issues = "https://github.com/TeamEpochGithub/epochlib/issues"
+"Release notes" = "https://github.com/TeamEpochGithub/epochlib/releases"
[tool.rye]
managed = true
lock-with-sources = true
dev-dependencies = [
"pre-commit>=3.7.1",
- "pytest>=8.1.1",
+ "pytest>=8.3.2",
"pytest-cov>=5.0.0",
"sphinx>=7.2.6",
"sphinx-autodoc-typehints>=2.0.0",
@@ -95,7 +92,7 @@ build-backend = "hatchling.build"
allow-direct-references = true
[tool.hatch.build.targets.wheel]
-packages = ["epochalyst"]
+packages = ["epochlib"]
[tool.pydoclint]
style = "sphinx"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 230a345..e91e3c4 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -7,18 +7,17 @@
# all-features: true
# with-sources: true
# generate-hashes: false
+# universal: false
--index-url https://pypi.org/simple/
-e file:.
-agogos==0.4
- # via epochalyst
alabaster==0.7.16
# via sphinx
annotated-types==0.7.0
- # via epochalyst
+ # via epochlib
audiomentations==0.36.0
- # via epochalyst
+ # via epochlib
audioread==3.0.1
# via librosa
babel==2.15.0
@@ -46,10 +45,10 @@ coverage==7.5.4
# via pytest-cov
dask==2024.6.2
# via dask-expr
- # via epochalyst
+ # via epochlib
dask-expr==1.1.6
# via dask
- # via epochalyst
+ # via epochlib
decorator==5.1.1
# via librosa
distlib==0.3.8
@@ -89,12 +88,11 @@ jinja2==3.1.4
# via sphinx
# via torch
joblib==1.4.2
- # via agogos
- # via epochalyst
+ # via epochlib
# via librosa
# via scikit-learn
kornia==0.7.2
- # via epochalyst
+ # via epochlib
kornia-rs==0.1.3
# via kornia
lazy-loader==0.4
@@ -132,7 +130,7 @@ numpy==1.26.4
# via audiomentations
# via bottleneck
# via dask
- # via epochalyst
+ # via epochlib
# via librosa
# via numba
# via numexpr
@@ -166,19 +164,19 @@ nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
-nvidia-nccl-cu12==2.20.5
+nvidia-nccl-cu12==2.19.3
# via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.8.61
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
onnx==1.16.1
- # via epochalyst
+ # via epochlib
onnxruntime==1.18.0
- # via epochalyst
+ # via epochlib
openvino==2024.2.0
- # via epochalyst
+ # via epochlib
openvino-telemetry==2024.1.0
# via openvino
packaging==24.1
@@ -194,7 +192,7 @@ packaging==24.1
pandas==2.2.2
# via dask
# via dask-expr
- # via epochalyst
+ # via epochlib
partd==1.4.2
# via dask
pillow==10.3.0
@@ -205,7 +203,7 @@ platformdirs==4.2.2
pluggy==1.5.0
# via pytest
polars==0.20.31
- # via epochalyst
+ # via epochlib
pooch==1.8.2
# via librosa
pre-commit==3.7.1
@@ -220,7 +218,7 @@ pycparser==2.22
pygit2==1.15.0
pygments==2.18.0
# via sphinx
-pytest==8.2.2
+pytest==8.3.2
# via pytest-cov
pytest-cov==5.0.0
python-dateutil==2.9.0.post0
@@ -280,7 +278,7 @@ sympy==1.12.1
threadpoolctl==3.5.0
# via scikit-learn
timm==1.0.7
- # via epochalyst
+ # via epochlib
tomli==2.0.1
# via coverage
# via pytest
@@ -288,19 +286,19 @@ tomli==2.0.1
toolz==0.12.1
# via dask
# via partd
-torch==2.3.1
- # via epochalyst
+torch==2.2.2
+ # via epochlib
# via kornia
# via timm
# via torchvision
-torchvision==0.18.1
+torchvision==0.17.2
# via timm
tqdm==4.66.4
# via huggingface-hub
-triton==2.3.1
+triton==2.2.0
# via torch
typing-extensions==4.12.2
- # via epochalyst
+ # via epochlib
# via huggingface-hub
# via librosa
# via torch
diff --git a/requirements.lock b/requirements.lock
index 55fabe4..f20a68e 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -7,16 +7,15 @@
# all-features: true
# with-sources: true
# generate-hashes: false
+# universal: false
--index-url https://pypi.org/simple/
-e file:.
-agogos==0.4
- # via epochalyst
annotated-types==0.7.0
- # via epochalyst
+ # via epochlib
audiomentations==0.36.0
- # via epochalyst
+ # via epochlib
audioread==3.0.1
# via librosa
bottleneck==1.4.0
@@ -35,10 +34,10 @@ coloredlogs==15.0.1
# via onnxruntime
dask==2024.6.2
# via dask-expr
- # via epochalyst
+ # via epochlib
dask-expr==1.1.6
# via dask
- # via epochalyst
+ # via epochlib
decorator==5.1.1
# via librosa
filelock==3.15.4
@@ -62,12 +61,11 @@ importlib-metadata==7.2.1
jinja2==3.1.4
# via torch
joblib==1.4.2
- # via agogos
- # via epochalyst
+ # via epochlib
# via librosa
# via scikit-learn
kornia==0.7.2
- # via epochalyst
+ # via epochlib
kornia-rs==0.1.3
# via kornia
lazy-loader==0.4
@@ -95,7 +93,7 @@ numpy==1.26.4
# via audiomentations
# via bottleneck
# via dask
- # via epochalyst
+ # via epochlib
# via librosa
# via numba
# via numexpr
@@ -129,19 +127,19 @@ nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
-nvidia-nccl-cu12==2.20.5
+nvidia-nccl-cu12==2.19.3
# via torch
-nvidia-nvjitlink-cu12==12.5.40
+nvidia-nvjitlink-cu12==12.8.61
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
onnx==1.16.1
- # via epochalyst
+ # via epochlib
onnxruntime==1.18.0
- # via epochalyst
+ # via epochlib
openvino==2024.2.0
- # via epochalyst
+ # via epochlib
openvino-telemetry==2024.1.0
# via openvino
packaging==24.1
@@ -155,7 +153,7 @@ packaging==24.1
pandas==2.2.2
# via dask
# via dask-expr
- # via epochalyst
+ # via epochlib
partd==1.4.2
# via dask
pillow==10.3.0
@@ -163,7 +161,7 @@ pillow==10.3.0
platformdirs==4.2.2
# via pooch
polars==0.20.31
- # via epochalyst
+ # via epochlib
pooch==1.8.2
# via librosa
protobuf==5.27.1
@@ -206,23 +204,23 @@ sympy==1.12.1
threadpoolctl==3.5.0
# via scikit-learn
timm==1.0.7
- # via epochalyst
+ # via epochlib
toolz==0.12.1
# via dask
# via partd
-torch==2.3.1
- # via epochalyst
+torch==2.2.2
+ # via epochlib
# via kornia
# via timm
# via torchvision
-torchvision==0.18.1
+torchvision==0.17.2
# via timm
tqdm==4.66.4
# via huggingface-hub
-triton==2.3.1
+triton==2.2.0
# via torch
typing-extensions==4.12.2
- # via epochalyst
+ # via epochlib
# via huggingface-hub
# via librosa
# via torch
diff --git a/ruff.toml b/ruff.toml
index 1db4a53..b07f196 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -32,6 +32,7 @@ ignore = [
# flake8-annotations (ANN)
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
+ "ANN401", # Allow Any type in epochlib
# flake8-errmsg (EM)
"EM101", # Exception must not use a string literal, assign to variable first
"EM102", # Exception must not use an f-string literal, assign to variable first
@@ -72,6 +73,19 @@ external = [
# Pylint (PL), Refactor (R)
"PLR2004", # Magic values in comparison allowed in tests
]
+"test_*.py" = [
+ # flake8-bandit (S)
+ "S101", # Asserts allowed in tests
+ # flake8-pytest-style (PT)
+ "PT009", # Use a regular assert instead of unittest-style {assertion}
+ # Pylint (PL), Refactor (R)
+ "PLR2004", # Magic values in comparison allowed in tests
+ #
+ # "ANN201", # Missing return value allowed in tests
+ "D100", # Docstring methods
+ "D101",
+ "D102",
+]
[lint.flake8-annotations]
allow-star-arg-any = true
diff --git a/tests/caching/test__cacher.py b/tests/caching/test__cacher.py
index f119c28..9a9f1de 100644
--- a/tests/caching/test__cacher.py
+++ b/tests/caching/test__cacher.py
@@ -5,8 +5,8 @@
import polars as pl
import pytest
-from epochalyst.caching.cacher import Cacher
-from epochalyst.logging.logger import Logger
+from epochlib.caching.cacher import Cacher
+from epochlib.logging.logger import Logger
from tests.constants import TEMP_DIR
diff --git a/tests/data/test_enum_data_format.py b/tests/data/test_enum_data_format.py
new file mode 100644
index 0000000..556f010
--- /dev/null
+++ b/tests/data/test_enum_data_format.py
@@ -0,0 +1,131 @@
+from unittest import TestCase
+from dataclasses import dataclass
+
+from epochlib.data import DataRetrieval, Data
+import numpy as np
+import numpy.typing as npt
+from typing import Any
+
+
+class TestDataRetrieval(TestCase):
+ def test_inherited_retrieval(self) -> None:
+ class ExpandedDataRetrieval(DataRetrieval):
+ BASE = 2**0
+ NEW = 2**1
+
+ self.assertTrue(ExpandedDataRetrieval.BASE == 2**0)
+ self.assertTrue(ExpandedDataRetrieval.NEW == 2**1)
+
+class TestData(TestCase):
+
+ def test___get_item__(self) -> None:
+ """Should raise an error if getitem has not been implemented."""
+ non_implemented = Data()
+ with self.assertRaises(NotImplementedError):
+ non_implemented[0]
+
+ def test___get_items__(self) -> None:
+ """Should raise an error if getitems has not been implemented."""
+ non_implemented = Data()
+ with self.assertRaises(NotImplementedError):
+ non_implemented[0:1]
+
+ def test___len__(self) -> None:
+ """Should raise an error if length has not been implemented."""
+ non_implemented = Data()
+ with self.assertRaises(NotImplementedError):
+ len(non_implemented)
+
+class TestDataRetrievalCombination(TestCase):
+
+ def test_get_empty_data(self) -> None:
+ test_data = CustomData()
+ with self.assertRaises(TypeError):
+ test_data[0]
+
+ def test_get_data1(self) -> None:
+ test_data = CustomData()
+ test_data.retrieval = TestRetrieval.BASE
+ test_data.data1 = [0,1]
+ self.assertTrue(test_data[1] == 1)
+
+ def test_get_data2(self) -> None:
+ test_data = CustomData()
+ test_data.retrieval = TestRetrieval.NEW
+ test_data.data2 = [0,1]
+ self.assertTrue(test_data[1] == 1)
+
+ def test_get_data2_with_both(self) -> None:
+ test_data = CustomData()
+ test_data.retrieval = TestRetrieval.NEW
+ test_data.data1 = [1,0]
+ test_data.data2 = [0,1]
+ self.assertTrue(test_data[1] == 1)
+
+ def test_get_items(self) -> None:
+ test_data = CustomData()
+ test_data.data1 = [0, 1, 2]
+ self.assertTrue(test_data[0:1] == [0])
+
+ def test_len(self) -> None:
+ test_data = CustomData()
+ test_data.data1 = [0,1]
+ self.assertTrue(len(test_data) == 2)
+
+
+class TestRetrieval(DataRetrieval):
+ BASE = 2**0
+ NEW = 2**1
+
+@dataclass
+class CustomData(Data):
+ data1: npt.NDArray[np.int_] | None = None
+ data2: npt.NDArray[np.int_] | None = None
+
+ def __post_init__(self) -> None:
+ self.retrieval = TestRetrieval.BASE
+
+ def __getitem__(self, idx: int | npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any] | list[Any]:
+ """Get item from the data.
+
+ :param idx: Index to retrieve
+ :return: Relevant data
+ """
+ result = []
+ if self.retrieval & TestRetrieval.BASE:
+ result.append(self.data1[idx])
+ if self.retrieval & TestRetrieval.NEW:
+ result.append(self.data2[idx])
+
+ if len(result) == 1:
+ return result[0]
+
+ return result
+
+ def __getitems__(self, indices: npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any]:
+ """Retrieve items for all indices based on specified retrieval flags.
+
+ :param indices: List of indices to retrieve
+ :return: Relevant data
+ """
+ result = []
+ if self.retrieval & TestRetrieval.BASE:
+ result.append(self.data1[indices])
+ if self.retrieval & TestRetrieval.NEW:
+ result.append(self.data2[indices])
+
+ if len(result) == 1:
+ return result[0]
+
+ return result
+
+ def __len__(self) -> int:
+ if self.data1:
+ return len(self.data1)
+ if self.data2:
+ return len(self.data2)
+ return 0
+
+
+
+
diff --git a/tests/data/test_pipeline_dataset.py b/tests/data/test_pipeline_dataset.py
new file mode 100644
index 0000000..d18bff4
--- /dev/null
+++ b/tests/data/test_pipeline_dataset.py
@@ -0,0 +1,180 @@
+from unittest import TestCase
+from dataclasses import dataclass
+
+from epochlib.data import Data, DataRetrieval, PipelineDataset
+from epochlib.training import TrainingBlock
+import numpy as np
+import numpy.typing as npt
+from typing import Any
+
+
+class TestDataRetrieval(DataRetrieval):
+ BASE = 2**0
+ FIRST = 2**1
+
+
+@dataclass
+class CustomData(Data):
+ data1: npt.NDArray[np.int_] | None = None
+ data2: npt.NDArray[np.int_] | None = None
+
+ def __post_init__(self) -> None:
+ self.retrieval = TestDataRetrieval.BASE
+
+ def __getitem__(self, idx: int | npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any] | list[Any]:
+ """Get item from the data.
+
+ :param idx: Index to retrieve
+ :return: Relevant data
+ """
+ if not isinstance(idx, (int | np.integer)):
+ return self.__getitems__(idx) # type: ignore[arg-type]
+
+ result = []
+ if self.retrieval & TestDataRetrieval.BASE:
+ result.append(self.data1[idx])
+ if self.retrieval & TestDataRetrieval.FIRST:
+ result.append(self.data2[idx])
+
+ if len(result) == 1:
+ return result[0]
+
+ return result
+
+ def __getitems__(self, indices: npt.NDArray[np.int_] | list[int] | slice) -> npt.NDArray[Any]:
+ """Retrieve items for all indices based on specified retrieval flags.
+
+ :param indices: List of indices to retrieve
+ :return: Relevant data
+ """
+ result = []
+ if self.retrieval & TestDataRetrieval.BASE:
+ result.append(self.data1[indices])
+ if self.retrieval & TestDataRetrieval.FIRST:
+ result.append(self.data2[indices])
+
+ if len(result) == 1:
+ return result[0]
+
+ return result
+
+ def __len__(self) -> int:
+ if self.data1 is not None:
+ return len(self.data1)
+ if self.data2 is not None:
+ return len(self.data2)
+ return 0
+
+
+class TestTrainingBlockNoAug(TrainingBlock):
+
+ def train(
+ self,
+ x: npt.NDArray[np.str_],
+ y: npt.NDArray[np.str_],
+ **train_args: Any,
+ ) -> tuple[npt.NDArray[np.str_], npt.NDArray[np.uint8]]:
+ """Randomize the SMILES string."""
+ return x, y
+
+ @property
+ def is_augmentation(self) -> bool:
+ """Check if augmentation is enabled."""
+ return False
+
+
+class TestTrainingBlockWithAug(TrainingBlock):
+
+ def train(
+ self,
+ x: npt.NDArray[np.str_],
+ y: npt.NDArray[np.str_],
+ **train_args: Any,
+ ) -> tuple[npt.NDArray[np.str_], npt.NDArray[np.uint8]]:
+ """Randomize the SMILES string."""
+ return x, y
+
+ @property
+ def is_augmentation(self) -> bool:
+ """Check if augmentation is enabled."""
+ return True
+
+
+class TestPipelineDataset(TestCase):
+
+ def test_initialization_errors(self) -> None:
+ with self.assertRaises(ValueError):
+ PipelineDataset()
+ with self.assertRaises(ValueError):
+ PipelineDataset(retrieval=['BASE'])
+
+ def test_initialization_steps(self) -> None:
+ pd = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval)
+ step1 = TestTrainingBlockNoAug()
+ pd_with_step = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step1])
+ self.assertEqual(pd_with_step._enabled_steps, [step1])
+
+ step_with_aug = TestTrainingBlockWithAug()
+ pd_aug = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step_with_aug])
+ self.assertEqual(pd_aug._enabled_steps, [])
+
+ pd_aug.setup_pipeline(use_augmentations=True)
+ self.assertEqual(pd_aug._enabled_steps, [step_with_aug])
+
+ def test_get_item(self) -> None:
+ test_data = CustomData()
+ test_data.data1 = [0, 1]
+ step = TestTrainingBlockNoAug()
+
+ pd_with_data = PipelineDataset(
+ retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data
+ )
+ self.assertEqual(pd_with_data[0][0], 0)
+
+ pd_with_indices = PipelineDataset(
+ retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data, indices=[1]
+ )
+ self.assertEqual(pd_with_indices[0][0], 1)
+
+ pd_no_data = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step])
+ with self.assertRaises(ValueError):
+ pd_no_data[0]
+
+ def test_get_items(self) -> None:
+ test_data = CustomData()
+ test_data.data1 = np.array([0, 1, 2])
+ step = TestTrainingBlockNoAug()
+
+ pd_with_data = PipelineDataset(
+ retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data
+ )
+ self.assertTrue((pd_with_data[[0, 1]][0] == [0, 1]).all())
+
+ pd_with_indices = PipelineDataset(
+ retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data, indices=np.array([0, 2])
+ )
+ self.assertTrue((pd_with_indices[[0, 1]][0] == [0, 2]).all())
+
+ pd_no_data = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step])
+ with self.assertRaises(ValueError):
+ pd_no_data[[0, 1]]
+
+ def test_len(self) -> None:
+ test_data = CustomData()
+ test_data.data1 = np.array([0, 1, 2])
+ step = TestTrainingBlockNoAug()
+
+ pd_with_data = PipelineDataset(
+ retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data
+ )
+ self.assertTrue(len(pd_with_data) == 3)
+
+ pd_with_data.initialize(x=test_data, y=test_data, indices=np.array([0, 2]))
+ # pd_with_indices = PipelineDataset(
+ # retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step], x=test_data, indices=np.array([0,2])
+ # )
+ self.assertTrue(len(pd_with_data) == 2)
+
+ pd_no_data = PipelineDataset(retrieval=['BASE'], retrieval_type=TestDataRetrieval, steps=[step])
+ with self.assertRaises(ValueError):
+ self.assertTrue(len(pd_no_data))
diff --git a/tests/logging/test_logger.py b/tests/logging/test_logger.py
index cfb8ca5..6682c55 100644
--- a/tests/logging/test_logger.py
+++ b/tests/logging/test_logger.py
@@ -3,7 +3,7 @@
import pytest
-from epochalyst.logging.logger import Logger
+from epochlib.logging.logger import Logger
test_string = "Test"
diff --git a/tests/pipeline/test__core.py b/tests/pipeline/test__core.py
new file mode 100644
index 0000000..8e26b22
--- /dev/null
+++ b/tests/pipeline/test__core.py
@@ -0,0 +1,170 @@
+from epochlib.pipeline import Block, Base, SequentialSystem, ParallelSystem
+from tests.pipeline.util import remove_cache_files
+from pathlib import Path
+
+
+class Test_Base:
+ def test_init(self):
+ base = Base()
+ assert base is not None
+
+ def test_set_hash(self):
+ base = Base()
+ prev_hash = base.get_hash()
+ base.set_hash("prev_hash")
+ assert base.get_hash() != prev_hash
+
+ def test_get_children(self):
+ base = Base()
+ assert base.get_children() == []
+
+ def test_get_parent(self):
+ base = Base()
+ assert base.get_parent() is None
+
+ def test__set_parent(self):
+ base = Base()
+ base.set_parent(base)
+ assert base.get_parent() == base
+
+ def test__set_children(self):
+ base = Base()
+ base.set_children([base])
+ assert base.get_children() == [base]
+
+ def test__repr_html_(self):
+ base = Base()
+ assert (
+ base._repr_html_()
+ == "Class: Base
- Hash: a00a595206d7eefcf0e87acf6e2e22ee
- Parent: None
- Children: None
"
+ )
+
+ def test_save_to_html(self):
+ html_path = Path("./tests/cache/test_html.html")
+ Path("./tests/cache/").mkdir(parents=True, exist_ok=True)
+ base = Base()
+ base.save_to_html(html_path)
+ assert Path.exists(html_path)
+ remove_cache_files()
+
+
+class TestBlock:
+ def test_block_init(self):
+ block = Block()
+ assert block is not None
+
+ def test_block_set_hash(self):
+ block = Block()
+ block.set_hash("")
+ hash1 = block.get_hash()
+ assert hash1 != ""
+ block.set_hash(hash1)
+ hash2 = block.get_hash()
+ assert hash2 != ""
+ assert hash1 != hash2
+
+ def test_block_get_hash(self):
+ block = Block()
+ block.set_hash("")
+ hash1 = block.get_hash()
+ assert hash1 != ""
+
+ def test__repr_html_(self):
+ block_instance = Block()
+
+ html_representation = block_instance._repr_html_()
+
+ assert html_representation is not None
+
+
+class TestSequentialSystem:
+ def test_system_init(self):
+ system = SequentialSystem()
+ assert system is not None
+
+ def test_system_hash_no_steps(self):
+ system = SequentialSystem()
+ assert system.get_hash() == ""
+
+ def test_system_hash_with_1_step(self):
+ block1 = Block()
+
+ system = SequentialSystem([block1])
+ assert system.get_hash() != ""
+ assert block1.get_hash() == system.get_hash()
+
+ def test_system_hash_with_2_steps(self):
+ block1 = Block()
+ block2 = Block()
+
+ system = SequentialSystem([block1, block2])
+ assert system.get_hash() != block1.get_hash()
+ assert (
+ system.get_hash() == block2.get_hash() != ""
+ )
+
+ def test_system_hash_with_3_steps(self):
+ block1 = Block()
+ block2 = Block()
+ block3 = Block()
+
+ system = SequentialSystem([block1, block2, block3])
+ assert system.get_hash() != block1.get_hash()
+ assert system.get_hash() != block2.get_hash()
+ assert block1.get_hash() != block2.get_hash()
+ assert (
+ system.get_hash() == block3.get_hash() != ""
+ )
+
+ def test__repr_html_(self):
+ block_instance = Block()
+ system_instance = SequentialSystem([block_instance, block_instance])
+ html_representation = system_instance._repr_html_()
+
+ assert html_representation is not None
+
+
+class TestParallelSystem:
+ def test_parallel_system_init(self):
+ parallel_system = ParallelSystem()
+ assert parallel_system is not None
+
+ def test_parallel_system_hash_no_steps(self):
+ system = ParallelSystem()
+ assert system.get_hash() == ""
+
+ def test_parallel_system_hash_with_1_step(self):
+ block1 = Block()
+
+ system = ParallelSystem([block1])
+ assert system.get_hash() != ""
+ assert block1.get_hash() == system.get_hash()
+
+ def test_parallel_system_hash_with_2_steps(self):
+ block1 = Block()
+ block2 = Block()
+
+ system = ParallelSystem([block1, block2])
+ assert system.get_hash() != block1.get_hash()
+ assert block1.get_hash() == block2.get_hash()
+ assert system.get_hash() != block2.get_hash()
+ assert system.get_hash() != ""
+
+ def test_parallel_system_hash_with_3_steps(self):
+ block1 = Block()
+ block2 = Block()
+ block3 = Block()
+
+ system = ParallelSystem([block1, block2, block3])
+ assert system.get_hash() != block1.get_hash()
+ assert system.get_hash() != block2.get_hash()
+ assert system.get_hash() != block3.get_hash()
+ assert block1.get_hash() == block2.get_hash() == block3.get_hash()
+ assert system.get_hash() != ""
+
+ def test_parallel_system__repr_html_(self):
+ block_instance = Block()
+ system_instance = ParallelSystem([block_instance, block_instance])
+ html_representation = system_instance._repr_html_()
+
+ assert html_representation is not None
diff --git a/tests/pipeline/test_training.py b/tests/pipeline/test_training.py
new file mode 100644
index 0000000..0b7e886
--- /dev/null
+++ b/tests/pipeline/test_training.py
@@ -0,0 +1,614 @@
+import pytest
+import warnings
+from epochlib.pipeline import Trainer, TrainingSystem, ParallelTrainingSystem, Pipeline
+from epochlib.pipeline import Transformer, TransformingSystem
+import numpy as np
+
+
+class TestTrainer:
+ def test_trainer_abstract_train(self):
+ trainer = Trainer()
+ with pytest.raises(NotImplementedError):
+ trainer.train([1, 2, 3], [1, 2, 3])
+
+ def test_trainer_abstract_predict(self):
+ trainer = Trainer()
+ with pytest.raises(NotImplementedError):
+ trainer.predict([1, 2, 3])
+
+ def test_trainer_train(self):
+ class trainerInstance(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ trainer = trainerInstance()
+ assert trainer.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_trainer_predict(self):
+ class trainerInstance(Trainer):
+ def predict(self, x):
+ return x
+
+ trainer = trainerInstance()
+ assert trainer.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_trainer_hash(self):
+ trainer = Trainer()
+ assert trainer.get_hash() != ""
+
+
+class TestTrainingSystem:
+ def test_training_system_init(self):
+ training_system = TrainingSystem()
+ assert training_system is not None
+
+ def test_training_system_init_with_steps(self):
+ class SubTrainer(Trainer):
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem(steps=[block1])
+ assert training_system is not None
+
+ def test_training_system_wrong_step(self):
+ class SubTrainer:
+ def predict(self, x):
+ return x
+
+ with pytest.raises(TypeError):
+ TrainingSystem(steps=[SubTrainer()])
+
+ def test_training_system_steps_changed_predict(self):
+ class SubTrainer:
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem()
+ training_system.steps = [block1]
+ with pytest.raises(TypeError):
+ training_system.predict([1, 2, 3])
+
+ def test_training_system_predict(self):
+ class SubTrainer(Trainer):
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem(steps=[block1])
+ assert training_system.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_trainsys_predict_with_trainer_and_trainsys(self):
+ class SubTrainer(Trainer):
+ def predict(self, x):
+ return x
+
+ block1 = SubTrainer()
+ block2 = SubTrainer()
+ block3 = TrainingSystem(steps=[block1, block2])
+ assert block2.get_parent() == block3
+ assert block1 in block3.get_children()
+ training_system = TrainingSystem(steps=[block1, block2, block3])
+ assert training_system.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_training_system_train(self):
+ class SubTrainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem(steps=[block1])
+ assert training_system.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_traiinsys_train_with_trainer_and_trainsys(self):
+ class SubTrainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ block1 = SubTrainer()
+ block2 = SubTrainer()
+ block3 = TrainingSystem(steps=[block1, block2])
+ training_system = TrainingSystem(steps=[block1, block2, block3])
+ assert training_system.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_training_system_steps_changed_train(self):
+ class SubTrainer:
+ def train(self, x, y):
+ return x, y
+
+ block1 = SubTrainer()
+ training_system = TrainingSystem()
+ training_system.steps = [block1]
+ with pytest.raises(TypeError):
+ training_system.train([1, 2, 3], [1, 2, 3])
+
+ def test_training_system_empty_hash(self):
+ training_system = TrainingSystem()
+ assert training_system.get_hash() == ""
+
+ def test_training_system_wrong_kwargs(self):
+ class Block1(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ class Block2(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TrainingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "block2": {}}
+ with pytest.warns(
+ UserWarning,
+ match="The following steps do not exist but were given in the kwargs:",
+ ):
+ system.train([1, 2, 3], [1, 2, 3], **kwargs)
+ system.predict([1, 2, 3], **kwargs)
+
+ def test_training_system_right_kwargs(self):
+ class Block1(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ class Block2(Trainer):
+ def train(self, x, y, **kwargs):
+ return x, y
+
+ def predict(self, x, **pred_args):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TrainingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "Block2": {}}
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ system.train([1, 2, 3], [1, 2, 3], **kwargs)
+ system.predict([1, 2, 3], **kwargs)
+ assert not caught_warnings
+
+
+class TestParallelTrainingSystem:
+ def test_PTrainSys_init(self):
+ system = ParallelTrainingSystem()
+
+ assert system is not None
+
+ def test_PTrainSys_init_trainers(self):
+ t1 = Trainer()
+ t2 = TrainingSystem()
+
+ system = ParallelTrainingSystem(steps=[t1, t2])
+
+ assert system is not None
+
+ def test_PTrainSys_init_wrong_trainers(self):
+ class WrongTrainer:
+ """Wrong trainer"""
+
+ t1 = WrongTrainer()
+
+ with pytest.raises(TypeError):
+ ParallelTrainingSystem(steps=[t1])
+
+ def test_PTrainSys_train(self):
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+
+ return data1 + data2
+
+ t1 = trainer()
+
+ system = pts(steps=[t1])
+
+ assert system is not None
+ assert system.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_PTrainSys_trainers(self):
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = trainer()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ assert system.train([1, 2, 3], [1, 2, 3]) == (
+ [1, 2, 3, 1, 2, 3],
+ [1, 2, 3, 1, 2, 3],
+ )
+
+ def test_PTrainSys_trainers_with_weights(self):
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class trainer2(Trainer):
+ def train(self, x, y):
+ return x * 3, y
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2 * weight
+ return data1 + data2 * weight
+
+ t1 = trainer()
+ t2 = trainer2()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ test = np.array([1, 2, 3])
+ preds, labels = system.train(test, test)
+ assert np.array_equal(preds, test * 2)
+ assert np.array_equal(labels, test)
+
+ def test_PTrainSys_predict(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+
+ system = pts(steps=[t1])
+
+ assert system is not None
+ assert system.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_PTrainSys_predict_with_trainsys(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = TrainingSystem(steps=[t1])
+
+ system = pts(steps=[t2, t1])
+
+ assert system is not None
+ assert system.predict([1, 2, 3]) == [1, 2, 3, 1, 2, 3]
+
+ def test_PTrainSys_predict_with_trainer_and_trainsys(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = trainer()
+ t3 = TrainingSystem(steps=[t1, t2])
+
+ system = pts(steps=[t1, t2, t3])
+
+ assert system is not None
+ assert t3.predict([1, 2, 3]) == [1, 2, 3]
+ assert system.predict([1, 2, 3]) == [1, 2, 3, 1, 2, 3, 1, 2, 3]
+
+ def test_PTrainSys_predictors(self):
+ class trainer(Trainer):
+ def predict(self, x):
+ return x
+
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = trainer()
+ t2 = trainer()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ assert system.predict([1, 2, 3]) == [1, 2, 3, 1, 2, 3]
+
+ def test_PTrainSys_concat_labels_throws_error(self):
+ system = ParallelTrainingSystem()
+
+ with pytest.raises(NotImplementedError):
+ system.concat_labels([1, 2, 3], [4, 5, 6])
+
+ def test_PTrainSys_step_1_changed(self):
+ system = ParallelTrainingSystem()
+
+ t1 = Transformer()
+ system.steps = [t1]
+
+ with pytest.raises(TypeError):
+ system.train([1, 2, 3], [1, 2, 3])
+
+ with pytest.raises(TypeError):
+ system.predict([1, 2, 3])
+
+ def test_PTrainSys_step_2_changed(self):
+ class pts(ParallelTrainingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+
+ return data1 + data2
+
+ system = pts()
+
+ class trainer(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ def predict(self, x):
+ return x
+
+ t1 = trainer()
+ t2 = Transformer()
+ system.steps = [t1, t2]
+
+ with pytest.raises(TypeError):
+ system.train([1, 2, 3], [1, 2, 3])
+
+ with pytest.raises(TypeError):
+ system.predict([1, 2, 3])
+
+ def test_train_parallel_hashes(self):
+ class SubTrainer1(Trainer):
+ def train(self, x, y):
+ return x, y
+
+ class SubTrainer2(Trainer):
+ def train(self, x, y):
+ return x * 2, y
+
+ block1 = SubTrainer1()
+ block2 = SubTrainer2()
+
+ system1 = ParallelTrainingSystem(steps=[block1, block2])
+ system1_copy = ParallelTrainingSystem(steps=[block1, block2])
+ system2 = ParallelTrainingSystem(steps=[block2, block1])
+ system2_copy = ParallelTrainingSystem(steps=[block2, block1])
+
+ assert system1.get_hash() == system2.get_hash()
+ assert system1.get_hash() == system1_copy.get_hash()
+ assert system2.get_hash() == system2_copy.get_hash()
+
+
+class TestPipeline:
+ def test_pipeline_init(self):
+ pipeline = Pipeline()
+ assert pipeline is not None
+
+ def test_pipeline_init_with_systems(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ label_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ label_sys=label_system,
+ )
+ assert pipeline is not None
+
+ def test_pipeline_train(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ label_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ label_sys=label_system,
+ )
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_no_y_system(self):
+ x_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_no_x_system(self):
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_no_train_system(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ post_system = TransformingSystem()
+ post_label_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=None,
+ pred_sys=post_system,
+ label_sys=post_label_system,
+ )
+ assert pipeline.train([1, 2], [1, 2]) == ([1, 2], [1, 2])
+
+ def test_pipeline_train_no_refining_system(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ pipeline = Pipeline(x_sys=x_system, y_sys=y_system, train_sys=training_system)
+ assert pipeline.train([1, 2, 3], [1, 2, 3]) == ([1, 2, 3], [1, 2, 3])
+
+ def test_pipeline_train_1_x_transform_block(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem(steps=[transform1])
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ result = pipeline.train(np.array([1, 2, 3]), [1, 2, 3])
+ assert np.array_equal(result[0], np.array([2, 4, 6])) and np.array_equal(
+ result[1], np.array([1, 2, 3])
+ )
+
+ def test_pipeline_predict(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_pipeline_predict_no_y_system(self):
+ x_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert pipeline.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_pipeline_predict_no_systems(self):
+ pipeline = Pipeline()
+ assert pipeline.predict([1, 2, 3]) == [1, 2, 3]
+
+ def test_pipeline_get_hash_no_change(self):
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ predicting_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=predicting_system,
+ )
+ assert x_system.get_hash() == ""
+ # assert y_system.get_hash() == ""
+ # assert training_system.get_hash() == ""
+ # assert predicting_system.get_hash() == ""
+ # assert pipeline.get_hash() == ""
+
+ def test_pipeline_get_hash_with_change(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem(steps=[transform1])
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem()
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert x_system.get_hash() != y_system.get_hash()
+ assert pipeline.get_hash() != ""
+
+ def test_pipeline_predict_system_hash(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem()
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem(steps=[transform1])
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert prediction_system.get_hash() != x_system.get_hash()
+ assert pipeline.get_hash() != ""
+
+ def test_pipeline_pre_post_hash(self):
+ class TransformingBlock(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ transform1 = TransformingBlock()
+ x_system = TransformingSystem(steps=[transform1])
+ y_system = TransformingSystem()
+ training_system = TrainingSystem()
+ prediction_system = TransformingSystem(steps=[transform1])
+ pipeline = Pipeline(
+ x_sys=x_system,
+ y_sys=y_system,
+ train_sys=training_system,
+ pred_sys=prediction_system,
+ )
+ assert x_system.get_hash() == prediction_system.get_hash()
+ assert pipeline.get_hash() != ""
diff --git a/tests/pipeline/test_transforming.py b/tests/pipeline/test_transforming.py
new file mode 100644
index 0000000..394d900
--- /dev/null
+++ b/tests/pipeline/test_transforming.py
@@ -0,0 +1,321 @@
+import warnings
+import numpy as np
+import pytest
+
+from epochlib.pipeline import Trainer
+from epochlib.pipeline import (
+ Transformer,
+ TransformingSystem,
+ ParallelTransformingSystem,
+)
+
+
+class TestTransformer:
+ def test_transformer_abstract(self):
+ transformer = Transformer()
+
+ with pytest.raises(NotImplementedError):
+ transformer.transform([1, 2, 3])
+
+ def test_transformer_transform(self):
+ class transformerInstance(Transformer):
+ def transform(self, data):
+ return data
+
+ transformer = transformerInstance()
+
+ assert transformer.transform([1, 2, 3]) == [1, 2, 3]
+
+ def test_transformer_hash(self):
+ transformer = Transformer()
+ assert transformer.get_hash() == "1cbcc4f2d0921b050d9b719d2beb6529"
+
+
+class TestTransformingSystem:
+ def test_transforming_system_init(self):
+ transforming_system = TransformingSystem()
+ assert transforming_system is not None
+
+ def test_transforming_system_init_with_steps(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ assert transforming_system is not None
+
+ def test_transforming_system_wrong_step(self):
+ class SubTransformer:
+ def transform(self, x):
+ return x
+
+ with pytest.raises(TypeError):
+ TransformingSystem(steps=[SubTransformer()])
+
+ def test_transforming_system_steps_changed(self):
+ class SubTransformer:
+ def transform(self, x):
+ return x
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem()
+ transforming_system.steps = [block1]
+ with pytest.raises(TypeError):
+ transforming_system.transform([1, 2, 3])
+
+ def test_transforming_system_transform_1_block(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ assert transforming_system.transform([1, 2, 3]) == [1, 2, 3]
+
+ def test_transforming_system_transform_1_block_with_args(self):
+ class SubTransformer(Transformer):
+ def transform(self, data):
+ return data
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ assert transforming_system.transform([1, 2, 3], **{"SubTransformer": {}}) == [
+ 1,
+ 2,
+ 3,
+ ]
+
+ def test_transforming_system_transform_2_blocks(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ block1 = SubTransformer()
+ block2 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1, block2])
+ result = transforming_system.transform(np.array([1, 2, 3]))
+ assert np.array_equal(result, np.array([4, 8, 12]))
+
+ def test_transformsys_with_transformsys(self):
+ class SubTransformer(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ block1 = SubTransformer()
+ block2 = TransformingSystem(steps=[block1])
+ transforming_system = TransformingSystem(steps=[block2])
+ result = transforming_system.transform(np.array([1, 2, 3]))
+ assert np.array_equal(result, np.array([2, 4, 6]))
+
+ def test_transforming_system_transform_with_args(self):
+ class SubTransformer(Transformer):
+ def transform(self, data, multiplier=2):
+ return data * multiplier
+
+ block1 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1])
+ result = transforming_system.transform(
+ np.array([1, 2, 3]), **{"SubTransformer": {"multiplier": 2}}
+ )
+ assert np.array_equal(result, np.array([2, 4, 6]))
+
+ def test_transforming_system_transform_with_args_2_blocks(self):
+ class SubTransformer(Transformer):
+ def transform(self, data, multiplier=2):
+ return data * multiplier
+
+ block1 = SubTransformer()
+ block2 = SubTransformer()
+ transforming_system = TransformingSystem(steps=[block1, block2])
+ result = transforming_system.transform(
+ np.array([1, 2, 3]), **{"SubTransformer": {"multiplier": 2}}
+ )
+ assert np.array_equal(result, np.array([4, 8, 12]))
+
+ def test_transforming_system_transform_with_recursive_args(self):
+ class SubTransformer(Transformer):
+ def transform(self, data, multiplier=2):
+ return data * multiplier
+
+ block1 = SubTransformer()
+ block2 = SubTransformer()
+ block3 = TransformingSystem(steps=[block2])
+ block4 = TransformingSystem(steps=[block3])
+ transforming_system = TransformingSystem(steps=[block1, block4])
+ assert np.array_equal(
+ transforming_system.transform(
+ np.array([1, 2, 3]), **{"SubTransformer": {"multiplier": 2}}
+ ),
+ np.array([4, 8, 12]),
+ )
+ assert np.array_equal(
+ transforming_system.transform(
+ np.array([1, 2, 3]),
+ **{
+ "TransformingSystem": {
+ "TransformingSystem": {"SubTransformer": {"multiplier": 3}}
+ }
+ },
+ ),
+ np.array([6, 12, 18]),
+ )
+
+ def test_transforming_system_empty_hash(self):
+ transforming_system = TransformingSystem()
+ assert transforming_system.get_hash() == ""
+
+ def test_transforming_system_wrong_kwargs(self):
+ class Block1(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ class Block2(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TransformingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "block2": {}}
+ with pytest.warns(
+ UserWarning,
+ match="The following steps do not exist but were given in the kwargs:",
+ ):
+ system.transform([1, 2, 3], **kwargs)
+
+ def test_transforming_system_right_kwargs(self):
+ class Block1(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ class Block2(Transformer):
+ def transform(self, x, **kwargs):
+ return x
+
+ block1 = Block1()
+ block2 = Block2()
+ system = TransformingSystem(steps=[block1, block2])
+ kwargs = {"Block1": {}, "Block2": {}}
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ system.transform([1, 2, 3], **kwargs)
+
+ assert not caught_warnings
+
+
+class TestParallelTransformingSystem:
+ def test_parallel_transforming_system(self):
+ # Create an instance of the system
+ system = ParallelTransformingSystem()
+
+ # Assert the system is an instance of ParallelTransformingSystem
+ assert isinstance(system, ParallelTransformingSystem)
+ assert system is not None
+
+ def test_parallel_transforming_system_wrong_step(self):
+ class SubTransformer:
+ def transform(self, x):
+ return x
+
+ with pytest.raises(TypeError):
+ ParallelTransformingSystem(steps=[SubTransformer()])
+
+ def test_parallel_transforming_system_transformers(self):
+ transformer1 = Transformer()
+ transformer2 = TransformingSystem()
+
+ system = ParallelTransformingSystem(steps=[transformer1, transformer2])
+ assert system is not None
+
+ def test_parallel_transforming_system_transform(self):
+ class transformer(Transformer):
+ def transform(self, data):
+ return data
+
+ class pts(ParallelTransformingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = transformer()
+
+ system = pts(steps=[t1])
+
+ assert system is not None
+ assert system.transform([1, 2, 3]) == [1, 2, 3]
+
+ def test_pts_transformers_transform(self):
+ class transformer(Transformer):
+ def transform(self, data):
+ return data
+
+ class pts(ParallelTransformingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ t1 = transformer()
+ t2 = transformer()
+
+ system = pts(steps=[t1, t2])
+
+ assert system is not None
+ assert system.transform([1, 2, 3]) == [1, 2, 3, 1, 2, 3]
+
+ def test_parallel_transforming_system_concat_throws_error(self):
+ system = ParallelTransformingSystem()
+
+ with pytest.raises(NotImplementedError):
+ system.concat([1, 2, 3], [4, 5, 6])
+
+ def test_pts_step_1_changed(self):
+ system = ParallelTransformingSystem()
+
+ t1 = Trainer()
+ system.steps = [t1]
+
+ with pytest.raises(TypeError):
+ system.transform([1, 2, 3])
+
+ def test_pts_step_2_changed(self):
+ class pts(ParallelTransformingSystem):
+ def concat(self, data1, data2, weight):
+ if data1 is None:
+ return data2
+ return data1 + data2
+
+ system = pts()
+
+ class transformer(Transformer):
+ def transform(self, data):
+ return data
+
+ t1 = transformer()
+ t2 = Trainer()
+ system.steps = [t1, t2]
+
+ with pytest.raises(TypeError):
+ system.transform([1, 2, 3])
+
+ def test_transform_parallel_hashes(self):
+ class SubTransformer1(Transformer):
+ def transform(self, x):
+ return x
+
+ class SubTransformer2(Transformer):
+ def transform(self, x):
+ return x * 2
+
+ block1 = SubTransformer1()
+ block2 = SubTransformer2()
+
+ system1 = ParallelTransformingSystem(steps=[block1, block2])
+ system1_copy = ParallelTransformingSystem(steps=[block1, block2])
+ system2 = ParallelTransformingSystem(steps=[block2, block1])
+ system2_copy = ParallelTransformingSystem(steps=[block2, block1])
+
+ assert system1.get_hash() == system2.get_hash()
+ assert system1.get_hash() == system1_copy.get_hash()
+ assert system2.get_hash() == system2_copy.get_hash()
diff --git a/tests/pipeline/util.py b/tests/pipeline/util.py
new file mode 100644
index 0000000..6eb0680
--- /dev/null
+++ b/tests/pipeline/util.py
@@ -0,0 +1,11 @@
+import glob
+import os
+
+
+def remove_cache_files():
+ files = glob.glob("tests/cache/*")
+ for f in files:
+ # If f is readme.md, skip it
+ if "README.md" in f:
+ continue
+ os.remove(f)
diff --git a/tests/test_ensemble.py b/tests/test_ensemble.py
index b998bc9..049b047 100644
--- a/tests/test_ensemble.py
+++ b/tests/test_ensemble.py
@@ -1,14 +1,12 @@
-import shutil
-from pathlib import Path
from typing import Any
import numpy as np
import pytest
-from epochalyst import EnsemblePipeline
-from epochalyst import ModelPipeline
-from epochalyst.transformation import TransformationPipeline
-from epochalyst.transformation import TransformationBlock
+from epochlib import EnsemblePipeline
+from epochlib import ModelPipeline
+from epochlib.transformation import TransformationPipeline
+from epochlib.transformation import TransformationBlock
from tests.constants import TEMP_DIR
diff --git a/tests/test_model.py b/tests/test_model.py
index f33554b..c890850 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,13 +1,11 @@
-import shutil
-from pathlib import Path
from typing import Any
import numpy as np
import pytest
-from epochalyst import ModelPipeline
-from epochalyst.transformation import TransformationPipeline
-from epochalyst.transformation import TransformationBlock
+from epochlib import ModelPipeline
+from epochlib.transformation import TransformationPipeline
+from epochlib.transformation import TransformationBlock
from tests.constants import TEMP_DIR
diff --git a/tests/training/augmentation/test_image_augmentations.py b/tests/training/augmentation/test_image_augmentations.py
index f7a53e3..f1499e1 100644
--- a/tests/training/augmentation/test_image_augmentations.py
+++ b/tests/training/augmentation/test_image_augmentations.py
@@ -1,6 +1,6 @@
import torch
-from epochalyst.training.augmentation import image_augmentations
+from epochlib.training.augmentation import image_augmentations
class TestImageAugmentations:
diff --git a/tests/training/augmentation/test_time_series_augmentations.py b/tests/training/augmentation/test_time_series_augmentations.py
index 662ea7e..006e558 100644
--- a/tests/training/augmentation/test_time_series_augmentations.py
+++ b/tests/training/augmentation/test_time_series_augmentations.py
@@ -1,7 +1,7 @@
import numpy as np
import torch
-from epochalyst.training.augmentation import time_series_augmentations
+from epochlib.training.augmentation import time_series_augmentations
def set_torch_seed(seed: int = 42) -> None:
diff --git a/tests/training/augmentation/test_utils.py b/tests/training/augmentation/test_utils.py
index d62ff2f..71cc74d 100644
--- a/tests/training/augmentation/test_utils.py
+++ b/tests/training/augmentation/test_utils.py
@@ -1,6 +1,6 @@
import torch
-from epochalyst.training.augmentation import utils
+from epochlib.training.augmentation import utils
def set_torch_seed(seed: int = 42) -> None:
diff --git a/tests/training/models/test_conv1d_bn_relu.py b/tests/training/models/test_conv1d_bn_relu.py
new file mode 100644
index 0000000..116c535
--- /dev/null
+++ b/tests/training/models/test_conv1d_bn_relu.py
@@ -0,0 +1,23 @@
+import torch
+from unittest import TestCase
+
+from epochlib.training.models import Conv1dBnRelu
+
+
+class TestConv1dBnRelu(TestCase):
+
+ conv1d_bn_relu = Conv1dBnRelu(in_channels=3, out_channels=1)
+
+ def test_conv1d_bn_relu_init(self):
+ assert self.conv1d_bn_relu is not None
+
+ def test_conv1d_bn_relu_forward(self):
+ input = torch.rand(16, 3, 1)
+ # Check there is no error thrown
+ self.conv1d_bn_relu.forward(input)
+
+ def test_conv1d_bn_relu_forward_without_bn(self):
+ conv1d_relu = Conv1dBnRelu(in_channels=3, out_channels=1, is_bn=False)
+ input = torch.rand(16, 3, 1)
+ # Check there is no error thrown
+ conv1d_relu.forward(input)
diff --git a/tests/training/models/test_timm.py b/tests/training/models/test_timm.py
index ab847d7..dc391b9 100644
--- a/tests/training/models/test_timm.py
+++ b/tests/training/models/test_timm.py
@@ -1,6 +1,7 @@
import torch
-from epochalyst.training.models import Timm
+from epochlib.training.models import Timm
+from torch import nn
class TestTimm:
@@ -12,4 +13,12 @@ def test_timm_init(self):
def test_timm_forward(self):
input = torch.rand(16, 3, 1, 1)
+ # Should not throw error
self.timm.forward(input)
+
+ def test_timm_activation(self):
+ timm_act = Timm(in_chans=3, num_classes=3, activation=nn.ReLU(), model_name="resnet18")
+ input = torch.rand(16, 3, 1, 1)
+ # Should not throw error
+ timm_act.forward(input)
+
diff --git a/tests/training/test_pretrain_block.py b/tests/training/test_pretrain_block.py
index afe0e43..10a2bbc 100644
--- a/tests/training/test_pretrain_block.py
+++ b/tests/training/test_pretrain_block.py
@@ -1,6 +1,6 @@
import pytest
-from epochalyst.training import PretrainBlock
+from epochlib.training import PretrainBlock
class TestPretrainBlock:
diff --git a/tests/training/test_torch_trainer.py b/tests/training/test_torch_trainer.py
index 4e88d58..8bb6245 100644
--- a/tests/training/test_torch_trainer.py
+++ b/tests/training/test_torch_trainer.py
@@ -4,9 +4,9 @@
import time
from dataclasses import dataclass
-from epochalyst.training._custom_data_parallel import _CustomDataParallel
+from epochlib.training._custom_data_parallel import _CustomDataParallel
-from epochalyst.training.torch_trainer import custom_collate
+from epochlib.training.torch_trainer import custom_collate
from typing import Any
from unittest.mock import patch
@@ -14,7 +14,7 @@
import pytest
import torch
-from epochalyst.training.torch_trainer import TorchTrainer
+from epochlib.training.torch_trainer import TorchTrainer
from tests.constants import TEMP_DIR
@@ -78,6 +78,26 @@ def test_init_none_args(self):
n_folds=1,
)
+ def test_model_name_none(self):
+ with pytest.raises(ValueError):
+ TorchTrainer(
+ model=self.simple_model,
+ criterion=torch.nn.MSELoss(),
+ optimizer=self.optimizer,
+ n_folds=0,
+ model_name=None,
+ )
+
+ def test_model_name_invalid(self):
+ with pytest.raises(ValueError):
+ TorchTrainer(
+ model=self.simple_model,
+ criterion=torch.nn.MSELoss(),
+ optimizer=self.optimizer,
+ n_folds=0,
+ model_name=" ",
+ )
+
def test_init_not_implemented(self):
with pytest.raises(NotImplementedError):
TorchTrainer(
diff --git a/tests/training/test_training.py b/tests/training/test_training.py
index ec495b3..d8ba215 100644
--- a/tests/training/test_training.py
+++ b/tests/training/test_training.py
@@ -1,9 +1,9 @@
import numpy as np
import pytest
-from agogos.training import Trainer
+from epochlib.pipeline import Trainer
-from epochalyst.training import TrainingPipeline
-from epochalyst.training import TrainingBlock
+from epochlib.training import TrainingPipeline
+from epochlib.training import TrainingBlock
from tests.constants import TEMP_DIR
diff --git a/tests/training/test_training_block.py b/tests/training/test_training_block.py
index accadb0..ecd5c56 100644
--- a/tests/training/test_training_block.py
+++ b/tests/training/test_training_block.py
@@ -2,7 +2,7 @@
import pytest
-from epochalyst.training import TrainingBlock
+from epochlib.training import TrainingBlock
TEMP_DIR = Path("tests/temp")
diff --git a/tests/training/utils/test_get_dependencies.py b/tests/training/utils/test_get_dependencies.py
index 2f91c22..1c3099f 100644
--- a/tests/training/utils/test_get_dependencies.py
+++ b/tests/training/utils/test_get_dependencies.py
@@ -3,7 +3,7 @@
import pytest
-from epochalyst.training.utils import _get_openvino, _get_onnxrt
+from epochlib.training.utils import _get_openvino, _get_onnxrt
class TestGetDependencies:
diff --git a/tests/training/utils/test_recursive_repr.py b/tests/training/utils/test_recursive_repr.py
index 81bcb75..fc8d706 100644
--- a/tests/training/utils/test_recursive_repr.py
+++ b/tests/training/utils/test_recursive_repr.py
@@ -1,4 +1,4 @@
-from epochalyst.training.utils.recursive_repr import recursive_repr
+from epochlib.training.utils.recursive_repr import recursive_repr
class TestRecursiveRepr:
diff --git a/tests/training/utils/test_tensor_functions.py b/tests/training/utils/test_tensor_functions.py
new file mode 100644
index 0000000..6fbf7bd
--- /dev/null
+++ b/tests/training/utils/test_tensor_functions.py
@@ -0,0 +1,11 @@
+from unittest import TestCase
+from epochlib.training.utils import batch_to_device
+from torch import Tensor
+
+
+class TestTensorFunctions(TestCase):
+
+ def test_unsupported_tensor_type(self) -> None:
+ with self.assertRaises(ValueError):
+ batch_to_device(Tensor([0,1]), 'failure', None)
+
diff --git a/tests/transformation/test_transformation.py b/tests/transformation/test_transformation.py
index 453462f..cb2d14b 100644
--- a/tests/transformation/test_transformation.py
+++ b/tests/transformation/test_transformation.py
@@ -1,12 +1,11 @@
-import shutil
from pathlib import Path
import numpy as np
import pytest
-from agogos.transforming import Transformer
+from epochlib.pipeline import Transformer
-from epochalyst.transformation import TransformationPipeline
-from epochalyst.transformation import TransformationBlock
+from epochlib.transformation import TransformationPipeline
+from epochlib.transformation import TransformationBlock
from tests.constants import TEMP_DIR
diff --git a/tests/transformation/test_transformation_block.py b/tests/transformation/test_transformation_block.py
index 710b580..05681a0 100644
--- a/tests/transformation/test_transformation_block.py
+++ b/tests/transformation/test_transformation_block.py
@@ -6,7 +6,7 @@
import numpy as np
import pytest
-from epochalyst.transformation import TransformationBlock
+from epochlib.transformation import TransformationBlock
TEMP_DIR = Path("tests/temp")