From 057cce46833cb743b9edc7b0139104334a8b2b7a Mon Sep 17 00:00:00 2001 From: datvo06 Date: Thu, 4 Dec 2025 11:21:30 -0500 Subject: [PATCH 1/2] Adding tests --- effectful/handlers/llm/synthesis.py | 56 ++++++++++++++++++++++++++--- tests/test_handlers_llm.py | 27 +++++++++++++- 2 files changed, 78 insertions(+), 5 deletions(-) diff --git a/effectful/handlers/llm/synthesis.py b/effectful/handlers/llm/synthesis.py index 3a77441b..4898de5d 100644 --- a/effectful/handlers/llm/synthesis.py +++ b/effectful/handlers/llm/synthesis.py @@ -1,10 +1,13 @@ import ast import collections.abc import dataclasses +import inspect import linecache import re import textwrap import typing +from collections.abc import Callable +from typing import get_args, get_type_hints from effectful.handlers.llm import Template from effectful.ops.semantics import fwd @@ -25,6 +28,48 @@ class ProgramSynthesis(ObjectInterpretation): """ + def __init__(self, type_check: bool = False): + """Initialize the program synthesis handler. + + Args: + type_check: Whether to verify the function signature matches the expected type. + """ + self.type_check = type_check + + def verify_callable_signature(self, func: Callable, expected_type: type) -> None: + """Verify that the function signature matches the expected type.""" + type_args = get_args(expected_type) + if not type_args: + return + + # For Callable[[P1, P2, ...], R], get_args returns ([P1, P2, ...], R) + # where the first element is a list of param types (or ... for Callable[..., R]) + expected_param_types, expected_return = type_args[0], type_args[-1] + + sig = inspect.signature(func) + actual_hints = get_type_hints(func) + + # Verify the return type + actual_return = actual_hints.get("return", inspect.Parameter.empty) + if actual_return != expected_return: + raise SynthesisError( + f"Return type mismatch: expected {expected_return}, got {actual_return}" + ) + + # Verify the parameter types (if specified and not ellipsis) + if expected_param_types is not ... and expected_param_types: + params = list(sig.parameters.values()) + if len(params) != len(expected_param_types): + raise SynthesisError( + f"Parameter count mismatch: expected {len(expected_param_types)}, got {len(params)}" + ) + for param, expected in zip(params, expected_param_types): + actual = actual_hints.get(param.name, inspect.Parameter.empty) + if actual != expected: + raise SynthesisError( + f"Parameter {param.name} type mismatch: expected {expected}, got {actual}" + ) + def _parse_and_eval[T](self, t: type[T], content: str) -> T: pattern = r"(.*?)" code_content = re.search(pattern, content, re.DOTALL) @@ -51,13 +96,15 @@ def _parse_and_eval[T](self, t: type[T], content: str) -> T: # register into linecache linecache.cache[filename] = (len(source_code), None, lines, filename) - # TODO: assert callable type compatibility gs: dict = {} try: code_obj = compile(source_code, filename, "exec") exec(code_obj, gs) + if self.type_check: + self.verify_callable_signature(gs[last_decl.name], t) + # TODO: even more static analysis and type checking, adding type guards, etc. except Exception as exc: - raise SynthesisError("evaluation failed", content) from exc + raise SynthesisError(f"evaluation failed: {exc}", content) from exc return gs[last_decl.name] @@ -65,9 +112,9 @@ def _parse_and_eval[T](self, t: type[T], content: str) -> T: def _call(self, template, *args, **kwargs) -> None: ret_type = template.__signature__.return_annotation origin = typing.get_origin(ret_type) - ret_type = ret_type if origin is None else origin + ret_type_origin = ret_type if origin is None else origin - if not (issubclass(ret_type, collections.abc.Callable)): # type: ignore[arg-type] + if not (issubclass(ret_type_origin, collections.abc.Callable)): # type: ignore[arg-type] return fwd() prompt_ext = textwrap.dedent(f""" @@ -95,6 +142,7 @@ def _call(self, template, *args, **kwargs) -> None: **kwargs, ) + # Pass the full ret_type (with type args) for type checking, not just the origin functional = self._parse_and_eval(ret_type, response) return functional diff --git a/tests/test_handlers_llm.py b/tests/test_handlers_llm.py index 2531ef0d..ff8c9199 100644 --- a/tests/test_handlers_llm.py +++ b/tests/test_handlers_llm.py @@ -1,7 +1,9 @@ from collections.abc import Callable +import pytest + from effectful.handlers.llm import Template -from effectful.handlers.llm.synthesis import ProgramSynthesis +from effectful.handlers.llm.synthesis import ProgramSynthesis, SynthesisError from effectful.ops.semantics import handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -115,3 +117,26 @@ def count_occurrences(s): assert callable(count_a) assert count_a("banana") == 3 assert count_a("cherry") == 0 + + +def test_count_char_with_program_synthesis_type_check(): + """Test the count_char template with program synthesis and type checking.""" + mock_code = """ +def count_occurrences(s: str) -> int: + return s.count('a') +""" + mock_provider = SingleResponseLLMProvider(mock_code) + + mock_error = """def count_occurrences(s: str) -> None: + return None""" + mock_error_provider = SingleResponseLLMProvider(mock_error) + + with handler(mock_provider), handler(ProgramSynthesis(type_check=True)): + count_a = count_char("a") + assert callable(count_a) + assert count_a("banana") == 3 + assert count_a("cherry") == 0 + + with pytest.raises(SynthesisError): + with handler(mock_error_provider), handler(ProgramSynthesis(type_check=True)): + count_a = count_char("a") From 705ac7dd69886853590eff6136665340a8245e6f Mon Sep 17 00:00:00 2001 From: datvo06 Date: Thu, 4 Dec 2025 11:27:14 -0500 Subject: [PATCH 2/2] Moving typecheck out of execution error --- effectful/handlers/llm/synthesis.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/effectful/handlers/llm/synthesis.py b/effectful/handlers/llm/synthesis.py index 4898de5d..8744fa54 100644 --- a/effectful/handlers/llm/synthesis.py +++ b/effectful/handlers/llm/synthesis.py @@ -100,12 +100,13 @@ def _parse_and_eval[T](self, t: type[T], content: str) -> T: try: code_obj = compile(source_code, filename, "exec") exec(code_obj, gs) - if self.type_check: - self.verify_callable_signature(gs[last_decl.name], t) - # TODO: even more static analysis and type checking, adding type guards, etc. + except Exception as exc: raise SynthesisError(f"evaluation failed: {exc}", content) from exc + if self.type_check: + self.verify_callable_signature(gs[last_decl.name], t) + # TODO: even more static analysis and type checking, adding type guards, etc. return gs[last_decl.name] @implements(Template.__call__)