diff --git a/effectful/handlers/llm/synthesis.py b/effectful/handlers/llm/synthesis.py
index 3a77441b..8744fa54 100644
--- a/effectful/handlers/llm/synthesis.py
+++ b/effectful/handlers/llm/synthesis.py
@@ -1,10 +1,13 @@
import ast
import collections.abc
import dataclasses
+import inspect
import linecache
import re
import textwrap
import typing
+from collections.abc import Callable
+from typing import get_args, get_type_hints
from effectful.handlers.llm import Template
from effectful.ops.semantics import fwd
@@ -25,6 +28,48 @@ class ProgramSynthesis(ObjectInterpretation):
"""
+ def __init__(self, type_check: bool = False):
+ """Initialize the program synthesis handler.
+
+ Args:
+ type_check: Whether to verify the function signature matches the expected type.
+ """
+ self.type_check = type_check
+
+ def verify_callable_signature(self, func: Callable, expected_type: type) -> None:
+ """Verify that the function signature matches the expected type."""
+ type_args = get_args(expected_type)
+ if not type_args:
+ return
+
+ # For Callable[[P1, P2, ...], R], get_args returns ([P1, P2, ...], R)
+ # where the first element is a list of param types (or ... for Callable[..., R])
+ expected_param_types, expected_return = type_args[0], type_args[-1]
+
+ sig = inspect.signature(func)
+ actual_hints = get_type_hints(func)
+
+ # Verify the return type
+ actual_return = actual_hints.get("return", inspect.Parameter.empty)
+ if actual_return != expected_return:
+ raise SynthesisError(
+ f"Return type mismatch: expected {expected_return}, got {actual_return}"
+ )
+
+ # Verify the parameter types (if specified and not ellipsis)
+ if expected_param_types is not ... and expected_param_types:
+ params = list(sig.parameters.values())
+ if len(params) != len(expected_param_types):
+ raise SynthesisError(
+ f"Parameter count mismatch: expected {len(expected_param_types)}, got {len(params)}"
+ )
+ for param, expected in zip(params, expected_param_types):
+ actual = actual_hints.get(param.name, inspect.Parameter.empty)
+ if actual != expected:
+ raise SynthesisError(
+ f"Parameter {param.name} type mismatch: expected {expected}, got {actual}"
+ )
+
def _parse_and_eval[T](self, t: type[T], content: str) -> T:
pattern = r"(.*?)"
code_content = re.search(pattern, content, re.DOTALL)
@@ -51,23 +96,26 @@ def _parse_and_eval[T](self, t: type[T], content: str) -> T:
# register into linecache
linecache.cache[filename] = (len(source_code), None, lines, filename)
- # TODO: assert callable type compatibility
gs: dict = {}
try:
code_obj = compile(source_code, filename, "exec")
exec(code_obj, gs)
+
except Exception as exc:
- raise SynthesisError("evaluation failed", content) from exc
+ raise SynthesisError(f"evaluation failed: {exc}", content) from exc
+ if self.type_check:
+ self.verify_callable_signature(gs[last_decl.name], t)
+ # TODO: even more static analysis and type checking, adding type guards, etc.
return gs[last_decl.name]
@implements(Template.__call__)
def _call(self, template, *args, **kwargs) -> None:
ret_type = template.__signature__.return_annotation
origin = typing.get_origin(ret_type)
- ret_type = ret_type if origin is None else origin
+ ret_type_origin = ret_type if origin is None else origin
- if not (issubclass(ret_type, collections.abc.Callable)): # type: ignore[arg-type]
+ if not (issubclass(ret_type_origin, collections.abc.Callable)): # type: ignore[arg-type]
return fwd()
prompt_ext = textwrap.dedent(f"""
@@ -95,6 +143,7 @@ def _call(self, template, *args, **kwargs) -> None:
**kwargs,
)
+ # Pass the full ret_type (with type args) for type checking, not just the origin
functional = self._parse_and_eval(ret_type, response)
return functional
diff --git a/tests/test_handlers_llm.py b/tests/test_handlers_llm.py
index 2531ef0d..ff8c9199 100644
--- a/tests/test_handlers_llm.py
+++ b/tests/test_handlers_llm.py
@@ -1,7 +1,9 @@
from collections.abc import Callable
+import pytest
+
from effectful.handlers.llm import Template
-from effectful.handlers.llm.synthesis import ProgramSynthesis
+from effectful.handlers.llm.synthesis import ProgramSynthesis, SynthesisError
from effectful.ops.semantics import handler
from effectful.ops.syntax import ObjectInterpretation, implements
@@ -115,3 +117,26 @@ def count_occurrences(s):
assert callable(count_a)
assert count_a("banana") == 3
assert count_a("cherry") == 0
+
+
+def test_count_char_with_program_synthesis_type_check():
+ """Test the count_char template with program synthesis and type checking."""
+ mock_code = """
+def count_occurrences(s: str) -> int:
+ return s.count('a')
+"""
+ mock_provider = SingleResponseLLMProvider(mock_code)
+
+ mock_error = """def count_occurrences(s: str) -> None:
+ return None"""
+ mock_error_provider = SingleResponseLLMProvider(mock_error)
+
+ with handler(mock_provider), handler(ProgramSynthesis(type_check=True)):
+ count_a = count_char("a")
+ assert callable(count_a)
+ assert count_a("banana") == 3
+ assert count_a("cherry") == 0
+
+ with pytest.raises(SynthesisError):
+ with handler(mock_error_provider), handler(ProgramSynthesis(type_check=True)):
+ count_a = count_char("a")