Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: check-useless-excludes
# - id: identity # Prints all files passed to pre-commits. Debugging.
- repo: https://github.com/lyz-code/yamlfix
rev: 1.9.0
rev: 1.17.0
hooks:
- id: yamlfix
- repo: local
Expand Down Expand Up @@ -61,17 +61,17 @@ repos:
rev: 1.13.0
hooks:
- id: blacken-docs
# - repo: https://github.com/PyCQA/docformatter
# rev: v1.5.1
# hooks:
# - id: docformatter
# args:
# - --in-place
# - --wrap-summaries
# - '88'
# - --wrap-descriptions
# - '88'
# - --blank
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.7
hooks:
- id: docformatter
args:
- --in-place
- --wrap-summaries
- '88'
- --wrap-descriptions
- '88'
- --blank
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.261
hooks:
Expand All @@ -91,7 +91,7 @@ repos:
- id: nbqa-black
- id: nbqa-ruff
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.22
hooks:
- id: mdformat
additional_dependencies:
Expand Down
2 changes: 1 addition & 1 deletion src/tranquilo/aggregate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def aggregator_identity(vector_model):

"""
n_params = vector_model.linear_terms.size
intercept = float(vector_model.intercepts)
intercept = float(vector_model.intercepts[0])
linear_terms = vector_model.linear_terms.flatten()
if vector_model.square_terms is None:
square_terms = np.zeros((n_params, n_params))
Expand Down
187 changes: 187 additions & 0 deletions src/tranquilo/batch_evaluators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""A collection of batch evaluators for process based parallelism.

All batch evaluators have the same interface and any function with the same interface
can be used used as batch evaluator.

"""

from joblib import Parallel, delayed

try:
from pathos.pools import ProcessPool

pathos_is_available = True
except ImportError:
pathos_is_available = False

from typing import Any, Callable, Literal, TypeVar

from tranquilo.config import DEFAULT_N_CORES as N_CORES
from tranquilo.decorators import catch, unpack
from tranquilo.options import ErrorHandling

T = TypeVar("T")


def pathos_mp_batch_evaluator(
func: Callable[..., T],
arguments: list[Any],
*,
n_cores: int = N_CORES,
error_handling: (
ErrorHandling | Literal["raise", "continue"]
) = ErrorHandling.CONTINUE,
unpack_symbol: Literal["*", "**"] | None = None,
) -> list[T]:
"""Batch evaluator based on pathos.multiprocess.ProcessPool.

This uses a patched but older version of python multiprocessing that replaces
pickling with dill and can thus handle decorated functions.

Args:
func (Callable): The function that is evaluated.
arguments (Iterable): Arguments for the functions. Their interperation
depends on the unpack argument.
n_cores (int): Number of cores used to evaluate the function in parallel.
Value below one are interpreted as one. If only one core is used, the
batch evaluator disables everything that could cause problems, i.e. in that
case func and arguments are never pickled and func is executed in the main
process.
error_handling (str): Can take the values "raise" (raise the error and stop all
tasks as soon as one task fails) and "continue" (catch exceptions and set
the traceback of the raised exception.
KeyboardInterrupt and SystemExit are always raised.
unpack_symbol (str or None). Can be "**", "*" or None. If None, func just takes
one argument. If "*", the elements of arguments are positional arguments for
func. If "**", the elements of arguments are keyword arguments for func.


Returns:
list: The function evaluations.

"""
if not pathos_is_available:
raise NotImplementedError(
"To use the pathos_mp_batch_evaluator, install pathos with "
"conda install -c conda-forge pathos."
)

_check_inputs(func, arguments, n_cores, error_handling, unpack_symbol)
n_cores = int(n_cores)

reraise = error_handling in [
"raise",
ErrorHandling.RAISE,
ErrorHandling.RAISE_STRICT,
]

@unpack(symbol=unpack_symbol)
@catch(default="__traceback__", reraise=reraise)
def internal_func(*args: Any, **kwargs: Any) -> T:
return func(*args, **kwargs)

if n_cores <= 1:
res = [internal_func(arg) for arg in arguments]
else:
p = ProcessPool(nodes=n_cores)
try:
res = p.map(internal_func, arguments)
except Exception as e:
p.terminate()
raise e

return res


def joblib_batch_evaluator(
func: Callable[..., T],
arguments: list[Any],
*,
n_cores: int = N_CORES,
error_handling: (
ErrorHandling | Literal["raise", "continue"]
) = ErrorHandling.CONTINUE,
unpack_symbol: Literal["*", "**"] | None = None,
) -> list[T]:
"""Batch evaluator based on joblib's Parallel.

Args:
func (Callable): The function that is evaluated.
arguments (Iterable): Arguments for the functions. Their interperation
depends on the unpack argument.
n_cores (int): Number of cores used to evaluate the function in parallel.
Value below one are interpreted as one. If only one core is used, the
batch evaluator disables everything that could cause problems, i.e. in that
case func and arguments are never pickled and func is executed in the main
process.
error_handling (str): Can take the values "raise" (raise the error and stop all
tasks as soon as one task fails) and "continue" (catch exceptions and set
the output of failed tasks to the traceback of the raised exception.
KeyboardInterrupt and SystemExit are always raised.
unpack_symbol (str or None). Can be "**", "*" or None. If None, func just takes
one argument. If "*", the elements of arguments are positional arguments for
func. If "**", the elements of arguments are keyword arguments for func.


Returns:
list: The function evaluations.

"""
_check_inputs(func, arguments, n_cores, error_handling, unpack_symbol)
n_cores = int(n_cores) if int(n_cores) >= 2 else 1

reraise = error_handling in [
"raise",
ErrorHandling.RAISE,
ErrorHandling.RAISE_STRICT,
]

@unpack(symbol=unpack_symbol)
@catch(default="__traceback__", reraise=reraise)
def internal_func(*args: Any, **kwargs: Any) -> T:
return func(*args, **kwargs)

if n_cores == 1:
res = [internal_func(arg) for arg in arguments]
else:
res = Parallel(n_jobs=n_cores)(delayed(internal_func)(arg) for arg in arguments)

return res


def _check_inputs(
func: Callable[..., T],
arguments: list[Any],
n_cores: int,
error_handling: ErrorHandling | Literal["raise", "continue"],
unpack_symbol: Literal["*", "**"] | None,
) -> None:
if not callable(func):
raise TypeError("func must be callable.")

try:
arguments = list(arguments)
except Exception as e:
raise ValueError("arguments must be list like.") from e

try:
int(n_cores)
except Exception as e:
raise ValueError("n_cores must be an integer.") from e

if unpack_symbol not in (None, "*", "**"):
raise ValueError(
f"unpack_symbol must be None, '*' or '**', not {unpack_symbol}"
)

if error_handling not in [
"raise",
"continue",
ErrorHandling.RAISE,
ErrorHandling.CONTINUE,
ErrorHandling.RAISE_STRICT,
]:
raise ValueError(
"error_handling must be 'raise' or 'continue' or ErrorHandling not "
f"{error_handling}"
)
12 changes: 12 additions & 0 deletions src/tranquilo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@
CRITERION_PENALTY_CONSTANT = 100


# ======================================================================================
# Check Available Packages
# ======================================================================================

try:
import optimagic # noqa: F401
except ImportError:
IS_OPTIMAGIC_INSTALLED = False
else:
IS_OPTIMAGIC_INSTALLED = True


# =================================================================================
# Dashboard Defaults
# =================================================================================
Expand Down
126 changes: 126 additions & 0 deletions src/tranquilo/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""This module contains various decorators.

There are two kinds of decorators defined in this module which consists of either two or
three nested functions. The former are decorators without and the latter with arguments.

For more information on decorators, see this `guide
`_ on https://realpython.com

which
provides a comprehensive overview.

.. _guide:

https://realpython.com/primer-on-python-decorators/

"""

import sys
from traceback import format_exception

import functools
import warnings


def catch(
func=None,
*,
exception=Exception,
exclude=(KeyboardInterrupt, SystemExit),
onerror=None,
default=None,
warn=True,
reraise=False,
):
"""Catch and handle exceptions.

This decorator can be used with and without additional arguments.

Args:
exception (Exception or tuple): One or several exceptions that
are caught and handled. By default all Exceptions are
caught and handled.
exclude (Exception or tuple): One or several exceptionts that
are not caught. By default those are KeyboardInterrupt and
SystemExit.
onerror (None or Callable): Callable that takes an Exception
as only argument. This is called when an exception occurs.
default: Value that is returned when as the output of func when
an exception occurs. Can be one of the following:
- a constant
- "__traceback__", in this case a string with a traceback is returned.
- callable with the same signature as func.
warn (bool): If True, the exception is converted to a warning.
reraise (bool): If True, the exception is raised after handling it.

"""

def decorator_catch(func):
@functools.wraps(func)
def wrapper_catch(*args, **kwargs):
try:
res = func(*args, **kwargs)
except exclude:
raise
except exception as e:
if onerror is not None:
onerror(e)

if reraise:
raise e

tb = get_traceback()

if warn:
msg = f"The following exception was caught:\n\n{tb}"
warnings.warn(msg)

if default == "__traceback__":
res = tb
elif callable(default):
res = default(*args, **kwargs)
else:
res = default
return res

return wrapper_catch

if callable(func):
return decorator_catch(func)
else:
return decorator_catch


def unpack(func=None, symbol=None):
def decorator_unpack(func):
if symbol is None:

@functools.wraps(func)
def wrapper_unpack(arg):
return func(arg)

elif symbol == "*":

@functools.wraps(func)
def wrapper_unpack(arg):
return func(*arg)

elif symbol == "**":

@functools.wraps(func)
def wrapper_unpack(arg):
return func(**arg)

return wrapper_unpack

if callable(func):
return decorator_unpack(func)
else:
return decorator_unpack


def get_traceback():
tb = format_exception(*sys.exc_info())
if isinstance(tb, list):
tb = "".join(tb)
return tb
Loading