Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions effectful/handlers/llm/synthesis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import ast
import collections.abc
import dataclasses
import inspect
import linecache
import re
import textwrap
import typing
from collections.abc import Callable
from typing import get_args, get_type_hints

from effectful.handlers.llm import Template
from effectful.ops.semantics import fwd
Expand All @@ -25,6 +28,48 @@ class ProgramSynthesis(ObjectInterpretation):

"""

def __init__(self, type_check: bool = False):
"""Initialize the program synthesis handler.

Args:
type_check: Whether to verify the function signature matches the expected type.
"""
self.type_check = type_check

def verify_callable_signature(self, func: Callable, expected_type: type) -> None:
"""Verify that the function signature matches the expected type."""
type_args = get_args(expected_type)
if not type_args:
return

# For Callable[[P1, P2, ...], R], get_args returns ([P1, P2, ...], R)
# where the first element is a list of param types (or ... for Callable[..., R])
expected_param_types, expected_return = type_args[0], type_args[-1]

sig = inspect.signature(func)
actual_hints = get_type_hints(func)

# Verify the return type
actual_return = actual_hints.get("return", inspect.Parameter.empty)
if actual_return != expected_return:
raise SynthesisError(
f"Return type mismatch: expected {expected_return}, got {actual_return}"
)

# Verify the parameter types (if specified and not ellipsis)
if expected_param_types is not ... and expected_param_types:
params = list(sig.parameters.values())
if len(params) != len(expected_param_types):
raise SynthesisError(
f"Parameter count mismatch: expected {len(expected_param_types)}, got {len(params)}"
)
for param, expected in zip(params, expected_param_types):
actual = actual_hints.get(param.name, inspect.Parameter.empty)
if actual != expected:
raise SynthesisError(
f"Parameter {param.name} type mismatch: expected {expected}, got {actual}"
)

def _parse_and_eval[T](self, t: type[T], content: str) -> T:
pattern = r"<code>(.*?)</code>"
code_content = re.search(pattern, content, re.DOTALL)
Expand All @@ -51,23 +96,26 @@ def _parse_and_eval[T](self, t: type[T], content: str) -> T:
# register into linecache
linecache.cache[filename] = (len(source_code), None, lines, filename)

# TODO: assert callable type compatibility
gs: dict = {}
try:
code_obj = compile(source_code, filename, "exec")
exec(code_obj, gs)

except Exception as exc:
raise SynthesisError("evaluation failed", content) from exc
raise SynthesisError(f"evaluation failed: {exc}", content) from exc

if self.type_check:
self.verify_callable_signature(gs[last_decl.name], t)
# TODO: even more static analysis and type checking, adding type guards, etc.
return gs[last_decl.name]

@implements(Template.__call__)
def _call(self, template, *args, **kwargs) -> None:
ret_type = template.__signature__.return_annotation
origin = typing.get_origin(ret_type)
ret_type = ret_type if origin is None else origin
ret_type_origin = ret_type if origin is None else origin

if not (issubclass(ret_type, collections.abc.Callable)): # type: ignore[arg-type]
if not (issubclass(ret_type_origin, collections.abc.Callable)): # type: ignore[arg-type]
return fwd()

prompt_ext = textwrap.dedent(f"""
Expand Down Expand Up @@ -95,6 +143,7 @@ def _call(self, template, *args, **kwargs) -> None:
**kwargs,
)

# Pass the full ret_type (with type args) for type checking, not just the origin
functional = self._parse_and_eval(ret_type, response)

return functional
27 changes: 26 additions & 1 deletion tests/test_handlers_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections.abc import Callable

import pytest

from effectful.handlers.llm import Template
from effectful.handlers.llm.synthesis import ProgramSynthesis
from effectful.handlers.llm.synthesis import ProgramSynthesis, SynthesisError
from effectful.ops.semantics import handler
from effectful.ops.syntax import ObjectInterpretation, implements

Expand Down Expand Up @@ -115,3 +117,26 @@ def count_occurrences(s):
assert callable(count_a)
assert count_a("banana") == 3
assert count_a("cherry") == 0


def test_count_char_with_program_synthesis_type_check():
"""Test the count_char template with program synthesis and type checking."""
mock_code = """<code>
def count_occurrences(s: str) -> int:
return s.count('a')
</code>"""
mock_provider = SingleResponseLLMProvider(mock_code)

mock_error = """<code>def count_occurrences(s: str) -> None:
return None</code>"""
mock_error_provider = SingleResponseLLMProvider(mock_error)

with handler(mock_provider), handler(ProgramSynthesis(type_check=True)):
count_a = count_char("a")
assert callable(count_a)
assert count_a("banana") == 3
assert count_a("cherry") == 0

with pytest.raises(SynthesisError):
with handler(mock_error_provider), handler(ProgramSynthesis(type_check=True)):
count_a = count_char("a")
Loading