diff --git a/effectful/handlers/futures/__init__.py b/effectful/handlers/futures/__init__.py new file mode 100644 index 00000000..625d5d4b --- /dev/null +++ b/effectful/handlers/futures/__init__.py @@ -0,0 +1,249 @@ +""" +Futures handler for effectful - provides integration with concurrent.futures. + +This module provides operations for working with concurrent.futures, allowing +effectful operations to be executed asynchronously in thread pools with +automatic preservation of interpretation context. +""" + +import concurrent.futures as futures +import functools +from collections.abc import Callable, Iterable +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Literal + +from effectful.ops.semantics import defop +from effectful.ops.syntax import ObjectInterpretation, defdata, implements +from effectful.ops.types import NotHandled, Term + + +class Executor: + """Namespace for executor-related operations.""" + + @staticmethod + @defop # type: ignore + def submit[**P, T]( + task: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> Future[T]: + """ + Submit a task for asynchronous execution. + + This operation should be handled by providing a FuturesInterpretation + which automatically preserves the interpretation context across thread boundaries. 
+ + :param task: The callable to execute asynchronously + :param args: Positional arguments for the task + :param kwargs: Keyword arguments for the task + :return: A Future representing the asynchronous computation + + Example: + >>> from concurrent.futures import ThreadPoolExecutor + >>> from effectful.handlers.futures import ThreadPoolFuturesInterpretation + >>> from effectful.ops.semantics import handler + >>> + >>> with handler(ThreadPoolFuturesInterpretation()): + >>> future = Executor.submit(lambda x,y: x + y, 1, 2) + """ + raise NotHandled + + @staticmethod + @defop + def map[T, R]( + func: Callable[[T], R], + *iterables: Iterable[T], + timeout: float | None = None, + chunksize: int = 1, + ) -> Iterable[R]: + """ + Map a function over iterables, executing asynchronously. + + Returns an iterator yielding results as they complete. Equivalent to + map(func, *iterables) but executes asynchronously. + + This operation should be handled by providing a FuturesInterpretation + which automatically preserves the interpretation context across thread boundaries. + + :param func: The function to map over the iterables + :param iterables: One or more iterables to map over + :param timeout: Maximum time to wait for a result (default: None) + :param chunksize: Size of chunks for ProcessPoolExecutor (default: 1) + :return: An iterator yielding results + + Example: + >>> from effectful.handlers.futures import ThreadPoolFuturesInterpretation + >>> from effectful.ops.semantics import handler + >>> + >>> def square(x): + >>> return x ** 2 + >>> + >>> with handler(ThreadPoolFuturesInterpretation()): + >>> results = list(Executor.map(square, range(10))) + >>> print(results) # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] + """ + raise NotHandled + + +class FuturesInterpretation(ObjectInterpretation): + """ + Base interpretation for concurrent.futures executors. 
+ + This interpretation automatically preserves the effectful interpretation context + when submitting tasks to worker threads, ensuring that effectful operations + work correctly across thread boundaries. + """ + + def __init__(self, executor: futures.Executor): + """ + Initialize the futures interpretation. + + :param executor: The executor to use (ThreadPoolExecutor or ProcessPoolExecutor) + """ + super().__init__() + self.executor: futures.Executor = executor + + def shutdown(self, *args, **kwargs): + self.executor.shutdown(*args, **kwargs) + + @implements(Executor.submit) + def submit(self, task: Callable, *args, **kwargs) -> Future: + """ + Submit a task to the executor with automatic context preservation. + + Captures the current interpretation context and ensures it is restored + in the worker thread before executing the task. + """ + from effectful.internals.runtime import get_interpretation, interpreter + + # Capture the current interpretation context + context = get_interpretation() + + # Submit the wrapped task to the underlying executor + return self.executor.submit(interpreter(context)(task), *args, **kwargs) + + @implements(Executor.map) + def map(self, func: Callable, *iterables, timeout=None, chunksize=1): + """ + Map a function over iterables with automatic context preservation. + + Captures the current interpretation context and ensures it is restored + in each worker thread before executing the function. 
+ """ + from effectful.internals.runtime import get_interpretation, interpreter + + # Capture the current interpretation context + context = get_interpretation() + + @functools.wraps(func) + def wrapped_func(*args, **kwargs): + # Restore the interpretation context in the worker thread + with interpreter(context): + return func(*args, **kwargs) + + # Call the executor's map with the wrapped function + return self.executor.map( + wrapped_func, *iterables, timeout=timeout, chunksize=chunksize + ) + + +class ThreadPoolFuturesInterpretation(FuturesInterpretation): + """ + Interpretation for ThreadPoolExecutor with automatic context preservation. + + Example: + >>> from concurrent.futures import ThreadPoolExecutor, Future + >>> from effectful.ops.syntax import defop + >>> from effectful.ops.semantics import handler + >>> from effectful.handlers.futures import Executor, ThreadPoolFuturesInterpretation + >>> + >>> @defop + >>> def pow(n: int, k: int) -> Future[int]: + >>> return Executor.submit(pow, n, k) + >>> + >>> pool = ThreadPoolExecutor() + >>> with handler(ThreadPoolFuturesInterpretation(pool)): + >>> result = pow(2, 10).result() + >>> print(result) # 1024 + """ + + def __init__(self, *args, **kwargs): + """ + Initialize with a ThreadPoolExecutor. 
+ + :param max_workers: Maximum number of worker threads (default: None, uses default from ThreadPoolExecutor) + """ + super().__init__(ThreadPoolExecutor(*args, **kwargs)) + + +type ReturnOptions = Literal["All_COMPLETED", "FIRST_COMPLETED", "FIRST_EXCEPTION"] + + +@dataclass(frozen=True) +class DoneAndNotDoneFutures[T]: + done: set[Future[T]] + not_done: set[Future[T]] + + +@defdata.register(DoneAndNotDoneFutures) +class _DoneAndNotDoneFuturesTerm[T](Term[DoneAndNotDoneFutures[T]]): + """Term representing a DoneAndNotDoneFutures result.""" + + def __init__(self, op, *args, **kwargs): + self._op = op + self._args = args + self._kwargs = kwargs + + @property + def op(self): + return self._op + + @property + def args(self): + return self._args + + @property + def kwargs(self): + return self._kwargs + + @defop # type: ignore[prop-decorator] + @property + def done(self) -> set[Future[T]]: + """Get the set of done futures.""" + if not isinstance(self, Term): + return self.done + else: + raise NotHandled + + @defop # type: ignore[prop-decorator] + @property + def not_done(self) -> set[Future[T]]: + """Get the set of not done futures.""" + if not isinstance(self, Term): + return self.not_done + else: + raise NotHandled + + +@defop +def wait[T]( + fs: Iterable[Future[T]], + timeout: int | None = None, + return_when: ReturnOptions = futures.ALL_COMPLETED, # type: ignore +) -> DoneAndNotDoneFutures[T]: + if ( + isinstance(timeout, Term) + or isinstance(return_when, Term) + or any(not isinstance(t, Future) for t in fs) + ): + raise NotHandled + return futures.wait(fs, timeout, return_when) # type: ignore + + +@defop +def as_completed[T]( + fs: Iterable[Future[T]], + timeout: int | None = None, +) -> Iterable[Future[T]]: + if isinstance(timeout, Term) or any(isinstance(t, Term) for t in fs): + raise NotHandled + return futures.as_completed(fs, timeout) diff --git a/effectful/handlers/llm/providers.py b/effectful/handlers/llm/providers.py index b0b38ed5..ee33aca1 100644 --- 
a/effectful/handlers/llm/providers.py +++ b/effectful/handlers/llm/providers.py @@ -282,13 +282,13 @@ def __init__(self, client: openai.OpenAI, model_name: str = "gpt-4o"): self._client = client self._model_name = model_name - @implements(Template.__call__) - def _call[**P, T]( - self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs - ) -> T: - ret_type = template.__signature__.return_annotation - bound_args = template.__signature__.bind(*args, **kwargs) - bound_args.apply_defaults() + def _openai_api_call[**P, T, RT]( + self, + template: Template[P, T], + bound_args: inspect.BoundArguments, + ret_type: type[RT], + ) -> RT: + """Execute the actual OpenAI API call and decode the response.""" prompt = _OpenAIPromptFormatter().format_as_messages( template.__prompt_template__, **bound_args.arguments ) @@ -368,3 +368,12 @@ def _call[**P, T]( result = Result.model_validate_json(result_str) assert isinstance(result, Result) return result.value # type: ignore[attr-defined] + + @implements(Template.__call__) + def _call[**P, T]( + self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs + ) -> T: + ret_type = template.__signature__.return_annotation + bound_args = template.__signature__.bind(*args, **kwargs) + bound_args.apply_defaults() + return self._openai_api_call(template, bound_args, ret_type) diff --git a/effectful/internals/runtime.py b/effectful/internals/runtime.py index f99472fe..a954f36d 100644 --- a/effectful/internals/runtime.py +++ b/effectful/internals/runtime.py @@ -1,20 +1,25 @@ import contextlib -import dataclasses import functools +import threading from collections.abc import Callable, Mapping from effectful.ops.syntax import defop from effectful.ops.types import Interpretation, Operation -@dataclasses.dataclass -class Runtime[S, T]: +class Runtime[S, T](threading.local): + """Thread-local runtime for effectful interpretations.""" + interpretation: "Interpretation[S, T]" + def __init__(self): + super().__init__() + 
self.interpretation = {} + @functools.lru_cache(maxsize=1) def get_runtime() -> Runtime: - return Runtime(interpretation={}) + return Runtime() def get_interpretation(): diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index cc3857a7..b99c356f 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -19,6 +19,10 @@ def apply[**P, T](op: Operation[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """Apply ``op`` to ``args``, ``kwargs`` in interpretation ``intp``. + Handler execution is mutually exclusive by default - only one thread can + execute handlers at a time. Use :func:`release_handler_lock` to temporarily + release the lock for I/O operations that should allow concurrent execution. + Handling :func:`apply` changes the evaluation strategy of terms. **Example usage**: diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index fc2d753c..62a08c84 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -9,6 +9,7 @@ import typing import warnings from collections.abc import Callable, Iterable, Mapping +from concurrent.futures import Future from typing import Annotated, Any, Concatenate from effectful.ops.types import Annotation, Expr, NotHandled, Operation, Term @@ -1673,3 +1674,67 @@ class _IntegralTerm[T: numbers.Integral](_RationalTerm[T]): @defdata.register(bool) class _BoolTerm[T: bool](_IntegralTerm[T]): # type: ignore pass + + +# Future support +@defdata.register(Future) +class _FutureTerm[T](_BaseTerm[Future[T]]): + """Term representing a Future computation.""" + + @defop + def result(self: Future[T], timeout: float | None = None) -> T: + """Get the result of the future.""" + if not isinstance(self, Term): + return self.result(timeout=timeout) + else: + raise NotHandled + + @defop + def exception( + self: Future[T], timeout: float | None = None + ) -> BaseException | None: + """Get the exception from the future, if any.""" + if not isinstance(self, Term): + return self.exception(timeout=timeout) + 
else: + raise NotHandled + + @defop + def cancel(self: Future[T]) -> bool: + """Attempt to cancel the future.""" + if not isinstance(self, Term): + return self.cancel() + else: + raise NotHandled + + @defop + def cancelled(self: Future[T]) -> bool: + """Check if the future was cancelled.""" + if not isinstance(self, Term): + return self.cancelled() + else: + raise NotHandled + + @defop + def done(self: Future[T]) -> bool: + """Check if the future is done.""" + if not isinstance(self, Term): + return self.done() + else: + raise NotHandled + + @defop + def running(self: Future[T]) -> bool: + """Check if the future is currently running.""" + if not isinstance(self, Term): + return self.running() + else: + raise NotHandled + + @defop + def add_done_callback(self: Future[T], fn: Callable[[Future[T]], None]) -> None: + """Add a callback to be called when the future completes.""" + if not isinstance(self, Term): + return self.add_done_callback(fn) + else: + raise NotHandled diff --git a/tests/test_handlers_futures.py b/tests/test_handlers_futures.py new file mode 100644 index 00000000..f1df7992 --- /dev/null +++ b/tests/test_handlers_futures.py @@ -0,0 +1,216 @@ +""" +Tests for the futures handler (effectful.handlers.futures). + +This module tests the integration of concurrent.futures with effectful, +including context preservation across thread boundaries. 
"""
Tests for the futures handler (effectful.handlers.futures).

This module tests the integration of concurrent.futures with effectful,
including context preservation across thread boundaries.
"""

import contextlib
import time
from collections.abc import Iterator
from concurrent.futures import Future
from threading import Barrier, RLock

import effectful.handlers.futures as futures
from effectful.handlers.futures import (
    Executor,
    ThreadPoolFuturesInterpretation,
)
from effectful.ops.semantics import NotHandled, defop, evaluate, handler
from effectful.ops.types import Term


@defop
def add(x: int, y: int) -> int:
    raise NotHandled


@defop
def a_mul(x: int, y: int) -> Future[int]:
    raise NotHandled


@defop
def a_div(x: int, y: int) -> Future[int]:
    raise NotHandled


@defop
def a_fac(n: int) -> Future[int]:
    raise NotHandled


@contextlib.contextmanager
def _thread_pool_handler(**pool_kwargs) -> Iterator[ThreadPoolFuturesInterpretation]:
    """Install a thread-pool futures handler and shut its pool down on exit.

    Without the explicit shutdown every test would leave its worker threads
    alive until interpreter exit.
    """
    interpretation = ThreadPoolFuturesInterpretation(**pool_kwargs)
    try:
        with handler(interpretation):
            yield interpretation
    finally:
        interpretation.shutdown()


def test_uninterp_async():
    """calling async func without interpretation returns term"""
    t = a_div(10, 20)
    assert isinstance(t, Term)


def test_mutual_exclusion():
    """Clients are responsible for locking around handlers with shared state.

    Handler execution is not mutually exclusive by default, just like any
    other call in Python. ``add_interp`` has a deliberate read-sleep-write
    race on ``add_calls``; because every client call takes ``client_lock``
    around ``add``, the handler calls are serialized and the counter ends at
    exactly 10.
    """
    add_calls = 0

    def add_interp(x: int, y: int) -> int:
        nonlocal add_calls
        no_calls = add_calls
        time.sleep(0.001)
        add_calls = no_calls + 1
        return x + y

    client_lock = RLock()

    def client(x: int):
        # hey, I'm running a function that may have shared state, let me lock it
        with client_lock:
            res = add(x, x)
        return res

    with _thread_pool_handler(max_workers=4), handler({add: add_interp}):
        _ = sum(Executor.map(client, range(10)))

    # The client-side lock serializes the handler, so no update is lost.
    assert add_calls == 10


def test_concurrent_client_execution():
    """Handler calls made from different worker threads overlap in time.

    The handler blocks on a two-party barrier, which can only be released if
    two handler invocations are in flight simultaneously. If handler calls
    were serialized, the barrier would time out and raise
    ``BrokenBarrierError`` (surfacing through ``Executor.map``).
    """
    workers = 2
    rendezvous = Barrier(workers, timeout=5)

    def add_interp(x: int, y: int) -> int:
        rendezvous.wait()
        return x + y

    def client(x: int) -> int:
        # clients submitted to the executor ARE NOT synchronous
        return add(x, x)

    with _thread_pool_handler(max_workers=workers), handler({add: add_interp}):
        results = list(Executor.map(client, range(workers)))

    assert results == [0, 2]


def test_wait_several_futures():
    def client_code():
        finished = futures.wait([a_div(3, 4), a_mul(4, 5)]).done
        return [fut.result() for fut in finished]

    def a_div_interp(x, y):
        return Executor.submit(lambda x, y: x / y, x, y)

    def a_mul_interp(x, y):
        return Executor.submit(lambda x, y: x * y, x, y)

    with (
        _thread_pool_handler(),
        handler({a_div: a_div_interp, a_mul: a_mul_interp}),
    ):
        assert set(client_code()) == {3 / 4, 4 * 5}


def test_eval_of_concurrent_terms():
    def client_code():
        # spawn two tasks in parallel
        r1 = a_div(3, 4)
        r2 = a_mul(3, 4)
        return r1.result() + r2.result()

    def a_div_interp(x, y):
        return Executor.submit(lambda x, y: x / y, x, y)

    def a_mul_interp(x, y):
        return Executor.submit(lambda x, y: x * y, x, y)

    # Without any handler the client code only builds a term.
    res_stx = client_code()
    assert isinstance(res_stx, Term)

    with (
        _thread_pool_handler(),
        handler({a_div: a_div_interp, a_mul: a_mul_interp}),
    ):
        res = client_code()
        assert res == (3 / 4 + 3 * 4)
        res = evaluate(res)
        assert res == (3 / 4 + 3 * 4)


def test_context_captured_at_submission():
    def submit_work():
        return Executor.submit(lambda: add(3, 4))

    def add_interp(x, y):
        return x + y

    def add_as_mul_interp(x, y):
        return x * y

    with _thread_pool_handler():
        with handler({add: add_interp}):
            future = submit_work()

        # Retrieve result in a different context
        with handler({add: add_as_mul_interp}):
            result = future.result()

        # The result should be 7 (from submission context), not 12
        assert result == 7

    # Also test retrieving result completely outside any interpretation
    with _thread_pool_handler(), handler({add: add_interp}):
        future = submit_work()

    # Retrieve result outside the handler context entirely
    result = future.result()
    assert result == 7


def test_concurrent_execution_faster_than_sequential():
    # Long enough that thread start-up cost is small compared to the sleep.
    sleep_duration = 0.05  # 50ms per task

    def add_with_sleep(x, y):
        # Report how long this handler call took, so the test can compare
        # the summed per-task time with the overall wall-clock time.
        start = time.perf_counter()
        time.sleep(sleep_duration)
        return time.perf_counter() - start

    with _thread_pool_handler(max_workers=3), handler({add: add_with_sleep}):
        start = time.perf_counter()

        # Submit three tasks concurrently
        f1 = Executor.submit(lambda: add(1, 2))
        f2 = Executor.submit(lambda: add(3, 4))
        f3 = Executor.submit(lambda: add(5, 6))

        # Get all results
        sequential_time = f1.result() + f2.result() + f3.result()
        elapsed = time.perf_counter() - start

    assert elapsed < sequential_time
+""" + +import time +from collections.abc import Callable +from concurrent.futures import Future +from inspect import BoundArguments +from typing import Any, override + +import effectful.handlers.futures as futures +from effectful.handlers.futures import Executor, ThreadPoolFuturesInterpretation +from effectful.handlers.llm import Template +from effectful.handlers.llm.providers import OpenAIAPIProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + + +class SlowMockLLMProvider(OpenAIAPIProvider): + """Mock provider that simulates slow LLM responses for testing concurrency.""" + + def __init__(self, response, delay: float = 0.05, mapping={}): + self.response = response + self.delay = delay + self.calls: list[tuple[Any, tuple[Any], type]] = [] + self.mapping = mapping + + @override + def _openai_api_call[T]( + self, template: Template, args: BoundArguments, retty: type[T] + ) -> T: + self.calls.append((template, args.args, retty)) + time.sleep(self.delay) + return self.mapping.get(template, {}).get(tuple(args.args), self.response) + + +@Template.define +def hiaku(topic: str) -> str: + """Return a hiaku about {topic}.""" + raise NotHandled + + +def test_future_return_type_decodes_inner_type(): + """Test that llm templates correctly decode to int, even wrapped in a future.""" + ref_hiaku = "apples to oranges, oranges to pears, I don't know what a hiaku is" + mock_provider = SlowMockLLMProvider(ref_hiaku, delay=0.001) + + with handler(ThreadPoolFuturesInterpretation()), handler(mock_provider): + future = Executor.submit(hiaku, "apples") + assert isinstance(future, Future) + result = future.result() + assert result == ref_hiaku + + +@Template.define +def generate_program(task: str) -> Callable[[int], int]: + """Generate a Python program that {task}.""" + raise NotHandled + + +def test_concurrent_program_generation(): + """Simulate concurrent LLM calls to generate Python programs and pick the best one.""" + # Mock responses for 
different approaches to the same task + responses = { + generate_program: { + ("implement fibonacci algorithm 0",): "def fib(n: int) -> int: return n", + ( + "implement fibonacci algorithm 1", + ): "def fib(n: int) -> int: return n * fib(n - 1)", + ( + "implement fibonacci algorithm 2", + ): "def fib(n: int) -> int: return fib(n - 2) + fib(n - 1) if n > 1 else 0", + } + } + + mock_provider = SlowMockLLMProvider( + response="print('Default')", delay=0.01, mapping=responses + ) + + user_request: str = "implement fibonacci algorithm" + + with handler(ThreadPoolFuturesInterpretation()), handler(mock_provider): + # Launch multiple LLM calls concurrently + tasks = [ + Executor.submit(generate_program, (user_request + f" {i}")) + for i in range(3) + ] + + # Collect all results as they finish + results_as_completed = (f.result() for f in futures.as_completed(tasks)) + + valid_results = [(result, len(result)) for result in results_as_completed] + + # Pick the "best" result (here: the shortest program, as a naive heuristic) + best_program = max(valid_results, key=lambda pair: pair[1])[0] + + # Assertions + assert len(valid_results) == 3 + assert best_program in set(responses[generate_program].values())