Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ repos:
hooks:
- id: mypy
additional_dependencies:
- annotated-types==0.7.0
- joblib-stubs>=1.4.2.3.20240619
args:
- "--fast-module-lookup"
Expand Down
47 changes: 26 additions & 21 deletions agogos/_core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""This module contains the core classes for all classes in the agogos package."""

from abc import abstractmethod
from __future__ import annotations

from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import Annotated, Generic, TypeVar

from annotated_types import MinLen
from joblib import hash

_DT = TypeVar("_DT")


@dataclass
class _Base:
Expand All @@ -27,11 +32,13 @@ def save_to_html(self, file_path: Path) -> None:
# Save html format to file_path
"""

_parent: _Base | None = field(default=None, init=False)
_children: Iterable[_Base] = field(default_factory=list, init=False)
_hash: str = field(default="", init=False)

def __post_init__(self) -> None:
"""Initialize the block."""
self._set_hash("")
self._set_parent(None)
self._set_children([])

def _set_hash(self, prev_hash: str) -> None:
"""Set the hash of the block.
Expand All @@ -47,14 +54,14 @@ def get_hash(self) -> str:
"""
return self._hash

def get_parent(self) -> Any: # noqa: ANN401
def get_parent(self) -> _Base | None:
"""Get the parent of the block.

:return: Parent of the block.
"""
return self._parent

def get_children(self) -> list[Any]:
def get_children(self) -> Iterable[_Base]:
"""Get the children of the block.

:return: Children of the block
Expand All @@ -70,14 +77,14 @@ def save_to_html(self, file_path: Path) -> None:
with open(file_path, "w") as file:
file.write(html)

def _set_parent(self, parent: Any) -> None: # noqa: ANN401
def _set_parent(self, parent: _Base | None) -> None:
"""Set the parent of the block.

:param parent: Parent of the block.
"""
self._parent = parent

def _set_children(self, children: list[Any]) -> None:
def _set_children(self, children: Iterable[_Base]) -> None:
"""Set the children of the block.

:param children: Children of the block.
Expand Down Expand Up @@ -128,7 +135,7 @@ def save_to_html(self, file_path: Path) -> None:


@dataclass
class _ParallelSystem(_Base):
class _ParallelSystem(Generic[_DT], _Base):
"""The _System class is the base class for all systems.

Parameters:
Expand All @@ -154,8 +161,8 @@ def save_to_html(self, file_path: Path) -> None:
# Save html format to file_path
"""

steps: list[_Base] = field(default_factory=list)
weights: list[float] = field(default_factory=list)
steps: Annotated[Sequence[_Base], MinLen(1)] = field(default_factory=list)
weights: Annotated[Sequence[float], MinLen(1)] = field(default_factory=list)

def __post_init__(self) -> None:
"""Post init function of _System class."""
Expand All @@ -170,26 +177,25 @@ def __post_init__(self) -> None:

# Set weights if they exist
if len(self.weights) == len(self.get_steps()):
[w / sum(self.weights) for w in self.weights]
self.weights = [w / sum(self.weights) for w in self.weights]
else:
num_steps = len(self.get_steps())
self.weights = [1 / num_steps for x in self.steps]
self.weights = [1 / num_steps] * num_steps

self._set_children(self.steps)

def get_steps(self) -> list[_Base]:
def get_steps(self) -> Annotated[Sequence[_Base], MinLen(1)]:
"""Return list of steps of _ParallelSystem.

:return: List of steps.
"""
if not self.steps:
return []
return self.steps

def get_weights(self) -> list[float]:
def get_weights(self) -> Annotated[Sequence[float], MinLen(1)]:
"""Return list of weights of _ParallelSystem.

:return: List of weights.
:raises TypeError: If mismatch between weights and steps.
"""
if len(self.get_steps()) != len(self.weights):
raise TypeError("Mismatch between weights and steps")
Expand Down Expand Up @@ -221,8 +227,7 @@ def _set_hash(self, prev_hash: str) -> None:

self._hash = hash(total)

@abstractmethod
def concat(self, original_data: Any, data_to_concat: Any, weight: float = 1.0) -> Any: # noqa: ANN401
def concat(self, original_data: _DT | None, data_to_concat: _DT, weight: float = 1.0) -> _DT:
"""Concatenate the transformed data.

:param original_data: The first input data.
Expand Down Expand Up @@ -255,7 +260,7 @@ def save_to_html(self, file_path: Path) -> None:
# Save html format to file_path
"""

steps: list[_Base] = field(default_factory=list)
steps: Annotated[Sequence[_Base], MinLen(1)] = field(default_factory=list)

def __post_init__(self) -> None:
"""Post init function of _System class."""
Expand All @@ -267,7 +272,7 @@ def __post_init__(self) -> None:

self._set_children(self.steps)

def get_steps(self) -> list[_Base]:
def get_steps(self) -> Sequence[_Base]:
"""Return list of steps of _ParallelSystem.

:return: List of steps.
Expand Down
Loading