diff --git a/effectful/handlers/llm/synthesis.py b/effectful/handlers/llm/synthesis.py index 3a77441b..8744fa54 100644 --- a/effectful/handlers/llm/synthesis.py +++ b/effectful/handlers/llm/synthesis.py @@ -1,10 +1,13 @@ import ast import collections.abc import dataclasses +import inspect import linecache import re import textwrap import typing +from collections.abc import Callable +from typing import get_args, get_type_hints from effectful.handlers.llm import Template from effectful.ops.semantics import fwd @@ -25,6 +28,48 @@ class ProgramSynthesis(ObjectInterpretation): """ + def __init__(self, type_check: bool = False): + """Initialize the program synthesis handler. + + Args: + type_check: Whether to verify the function signature matches the expected type. + """ + self.type_check = type_check + + def verify_callable_signature(self, func: Callable, expected_type: type) -> None: + """Verify that the function signature matches the expected type.""" + type_args = get_args(expected_type) + if not type_args: + return + + # For Callable[[P1, P2, ...], R], get_args returns ([P1, P2, ...], R) + # where the first element is a list of param types (or ... for Callable[..., R]) + expected_param_types, expected_return = type_args[0], type_args[-1] + + sig = inspect.signature(func) + actual_hints = get_type_hints(func) + + # Verify the return type + actual_return = actual_hints.get("return", inspect.Parameter.empty) + if actual_return != expected_return: + raise SynthesisError( + f"Return type mismatch: expected {expected_return}, got {actual_return}" + ) + + # Verify the parameter types (if specified and not ellipsis) + if expected_param_types is not ... and expected_param_types: + params = list(sig.parameters.values()) + if len(params) != len(expected_param_types): + raise SynthesisError( + f"Parameter count mismatch: expected {len(expected_param_types)}, got {len(params)}" + ) + for param, expected in zip(params, expected_param_types): + actual = actual_hints.get(param.name, inspect.Parameter.empty) + if actual != expected: + raise SynthesisError( + f"Parameter {param.name} type mismatch: expected {expected}, got {actual}" + ) + def _parse_and_eval[T](self, t: type[T], content: str) -> T: pattern = r"(.*?)" code_content = re.search(pattern, content, re.DOTALL) @@ -51,23 +96,26 @@ def _parse_and_eval[T](self, t: type[T], content: str) -> T: # register into linecache linecache.cache[filename] = (len(source_code), None, lines, filename) - # TODO: assert callable type compatibility gs: dict = {} try: code_obj = compile(source_code, filename, "exec") exec(code_obj, gs) + except Exception as exc: - raise SynthesisError("evaluation failed", content) from exc + raise SynthesisError(f"evaluation failed: {exc}", content) from exc + if self.type_check: + self.verify_callable_signature(gs[last_decl.name], t) + # TODO: even more static analysis and type checking, adding type guards, etc. return gs[last_decl.name] @implements(Template.__call__) def _call(self, template, *args, **kwargs) -> None: ret_type = template.__signature__.return_annotation origin = typing.get_origin(ret_type) - ret_type = ret_type if origin is None else origin + ret_type_origin = ret_type if origin is None else origin - if not (issubclass(ret_type, collections.abc.Callable)): # type: ignore[arg-type] + if not (issubclass(ret_type_origin, collections.abc.Callable)): # type: ignore[arg-type] return fwd() prompt_ext = textwrap.dedent(f""" @@ -95,6 +143,7 @@ def _call(self, template, *args, **kwargs) -> None: **kwargs, ) + # Pass the full ret_type (with type args) for type checking, not just the origin functional = self._parse_and_eval(ret_type, response) return functional diff --git a/tests/test_handlers_llm.py b/tests/test_handlers_llm.py index 2531ef0d..ff8c9199 100644 --- a/tests/test_handlers_llm.py +++ b/tests/test_handlers_llm.py @@ -1,7 +1,9 @@ from collections.abc import Callable +import pytest + from effectful.handlers.llm import Template -from effectful.handlers.llm.synthesis import ProgramSynthesis +from effectful.handlers.llm.synthesis import ProgramSynthesis, SynthesisError from effectful.ops.semantics import handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -115,3 +117,26 @@ def count_occurrences(s): assert callable(count_a) assert count_a("banana") == 3 assert count_a("cherry") == 0 + + +def test_count_char_with_program_synthesis_type_check(): + """Test the count_char template with program synthesis and type checking.""" + mock_code = """ +def count_occurrences(s: str) -> int: + return s.count('a') +""" + mock_provider = SingleResponseLLMProvider(mock_code) + + mock_error = """def count_occurrences(s: str) -> None: + return None""" + mock_error_provider = SingleResponseLLMProvider(mock_error) + + with handler(mock_provider), handler(ProgramSynthesis(type_check=True)): + count_a = count_char("a") + assert callable(count_a) + assert count_a("banana") == 3 + assert count_a("cherry") == 0 + + with pytest.raises(SynthesisError): + with handler(mock_error_provider), handler(ProgramSynthesis(type_check=True)): + count_a = count_char("a")