From ea07294ce32872efdb73be5084c0990b1ed624e1 Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Mon, 24 Nov 2025 12:17:09 -0500 Subject: [PATCH 1/2] implement first to k-ahead sampler --- effectful/handlers/llm/sampling.py | 47 ++++++++++++++++++++++++++++++ effectful/internals/runtime.py | 3 +- 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 effectful/handlers/llm/sampling.py diff --git a/effectful/handlers/llm/sampling.py b/effectful/handlers/llm/sampling.py new file mode 100644 index 00000000..effc197d --- /dev/null +++ b/effectful/handlers/llm/sampling.py @@ -0,0 +1,47 @@ +from collections import Counter +from concurrent import futures +from concurrent.futures.thread import ThreadPoolExecutor + +from effectful.handlers.llm import Template +from effectful.internals.runtime import get_interpretation, interpreter +from effectful.ops.semantics import fwd +from effectful.ops.syntax import ObjectInterpretation, implements + + +class KAheadSampler[**P, T](ObjectInterpretation): + no_voters: int + k: int + """Number of votes ahead before an answer is accepted""" + votes: Counter[T] = Counter() + + def __init__(self, no_voters: int = 6, k: int = 3): + self.no_voters = no_voters + self.k = k + + @implements(Template.__call__) + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + executor = ThreadPoolExecutor() + intp = get_interpretation() + tasks = [ + executor.submit(interpreter(intp)(fwd), *args, **kwargs) + for _ in range(self.no_voters) + ] + + def n_votes_ahead(): + match self.votes.most_common(2): + case [[_, v1], [_, v2]]: + return v1 >= v2 + self.k + case [[_, v1]]: + return v1 >= self.k + case _: + return False + + while not n_votes_ahead(): + done, remain = futures.wait(tasks, return_when=futures.FIRST_COMPLETED) + tasks = list(remain) + for fut in done: + res = fut.result() + self.votes[res] += 1 + tasks.append(executor.submit(interpreter(intp)(fwd), *args, **kwargs)) + executor.shutdown() + return self.votes.most_common(1)[0][0] diff --git a/effectful/internals/runtime.py b/effectful/internals/runtime.py index f99472fe..12488e13 100644 --- a/effectful/internals/runtime.py +++ b/effectful/internals/runtime.py @@ -2,13 +2,14 @@ import dataclasses import functools from collections.abc import Callable, Mapping +from threading import local from effectful.ops.syntax import defop from effectful.ops.types import Interpretation, Operation @dataclasses.dataclass -class Runtime[S, T]: +class Runtime[S, T](local): interpretation: "Interpretation[S, T]" From 795856e0b5d034bd6f2a247d526e9ffeed7b90ac Mon Sep 17 00:00:00 2001 From: Kiran Gopinathan Date: Mon, 24 Nov 2025 16:26:24 -0500 Subject: [PATCH 2/2] implemented a basic proof-search using effectful --- tests/test_proof_search.py | 265 +++++++++++++++++++++++++++++++++++++ 1 file changed, 265 insertions(+) create mode 100644 tests/test_proof_search.py diff --git a/tests/test_proof_search.py b/tests/test_proof_search.py new file mode 100644 index 00000000..bb101115 --- /dev/null +++ b/tests/test_proof_search.py @@ -0,0 +1,265 @@ +from collections import defaultdict, deque + +import pantograph as pg +from openai import OpenAI +from pantograph.expr import GoalState +from pantograph.message import ServerError +from pantograph.server import Server, Site, TacticFailure +from pydantic.dataclasses import dataclass + +from effectful.handlers.llm import Template +from effectful.handlers.llm.providers import OpenAIAPIProvider +from effectful.ops.semantics import handler +from effectful.ops.types import NotHandled + + +@dataclass(frozen=True) +class Induction: + target: str + + def to_string(self) -> str: + return f"induction {self.target}" + + +@dataclass(frozen=True) +class Simp: + lemmas: tuple[str, ...] | None = None + + def to_string(self) -> str: + return f"simp [{', '.join(self.lemmas)}]" if self.lemmas else "simp" + + +@dataclass(frozen=True) +class Rw: + lemmas: tuple[str, ...] + + def to_string(self) -> str: + return f"rw [{', '.join(self.lemmas)}]" + + +@dataclass(frozen=True) +class Grind: + def to_string(self) -> str: + return "grind" + + +Tactic = Induction | Simp | Rw | Grind + + +type StateId = int + + +@dataclass(frozen=True) +class SearchNode: + goal_state: GoalState + parent: "SearchNode | None" = None + tactic: Tactic | None = None + site: Site | None = None + + +class ProofSearcher: + def __init__( + self, server: Server, state: GoalState, predict_tactic_fn, context: str + ): + self.server = server + self.predict_tactic_fn = predict_tactic_fn + self.context = context + + self.work = deque([state.state_id]) + self.active = defaultdict(list) + self.active[state.state_id].append(SearchNode(state)) + + self.done: list[SearchNode] = [] + self.failed: list[tuple[SearchNode, str]] = [] + + # --- small utilities ----------------------------------------------------- + + def _rename_tactic(self, gs: GoalState, idx: int) -> str: + vs = gs.goals[idx].variables + bad = [v for v in vs if "✝" in (v.name or "")] + if not bad: + return "" + names = [v.name.replace("✝", "") for v in bad if v.name] + return f"rename_i {' '.join(names)}" + + def _normalize(self, gs: GoalState, idx: int) -> GoalState: + t = self._rename_tactic(gs, idx) + if not t: + return gs + try: + return self.server.goal_tactic(gs, tactic=t, site=Site(goal_id=idx)) + except (ServerError, TacticFailure): + return gs + + # --- core --------------------------------------------------------------- + + def step(self) -> bool: + if not self.work: + return False + + sid = self.work.popleft() + nodes = self.active.pop(sid, []) + progressed = False + + for node in nodes: + tactic = self._predict(node) + nxt = self._apply(node.goal_state, tactic, node.site) + + if isinstance(nxt, str): + self.failed.append((node, nxt)) + self.active[sid].append(node) + continue + + child = SearchNode(nxt, parent=node.parent, tactic=tactic, site=node.site) + + if not nxt.goals: + self.done.append(child) + continue + + progressed = True + for i in range(len(nxt.goals)): + gs = self._normalize(nxt, i) + n = SearchNode(gs, parent=child, site=Site(i)) + self.active[gs.state_id].append(n) + self.work.append(gs.state_id) + + if not progressed: + self.work.append(sid) + + return bool(self.done) + + # --- proof extraction ---------------------------------------------------- + + def extract(self, node: SearchNode | None) -> list[SearchNode]: + out = [] + while node: + out.append(node) + node = node.parent + return out[::-1] + + def render(self, node: SearchNode) -> str: + steps = self.extract(node) + out = [] + indent = "" + cur = None + + for s in steps: + if s.site and s.site.goal_id < len(s.goal_state.goals): + r = self._rename_tactic(s.goal_state, s.site.goal_id) + if r: + out.append(indent + r) + + name = s.goal_state.goals[s.site.goal_id].name + if name and name != cur: + out.append(f"{indent}case {name} =>") + indent = " " + cur = name + + if s.tactic: + out.append(indent + s.tactic.to_string()) + + return self.context + "\n".join(out) + + # --- prediction ---------------------------------------------------------- + + def _predict(self, node: SearchNode) -> Tactic: + if ( + node.site + and node.site.goal_id + and node.site.goal_id < len(node.goal_state.goals) + ): + g = node.goal_state.goals[node.site.goal_id] + else: + g = node.goal_state.goals[0] + + vars = [TypeBinding(v.name, v.t) for v in g.variables if v.name] + script = self.render(node) + t = self.predict_tactic_fn(str(g), vars, script) + + print("at proof script:\n" + script) + print("predicted:", t.to_string()) + return t + + # --- server wrapper ------------------------------------------------------ + + def _apply( + self, state: GoalState, tactic: Tactic, site: Site | None + ) -> GoalState | str: + try: + return self.server.goal_tactic( + state, tactic=tactic.to_string(), site=site or Site() + ) + except (ServerError, TacticFailure) as e: + return str(e) + + +@dataclass +class TypeBinding: + name: str + t: str + + def __str__(self): + return f"{self.name}: {self.t}" + + +@Template.define +def predict_tactic( + goal_state: str, variables: list[TypeBinding], proof_script: str +) -> Tactic: + """ + You are an experienced proof engineer, working for the Lean FRO. You are proficient with the internals of the Lean theorem prover. + You are currently working on a proof. This proof certifies mission critical software, and completing it will save engineers $200 worth of time. + + You have written the following proof script: + + {proof_script} + + The current goal state is: + + {goal_state} + + You have access to the following variables: + + {variables} + + You must predict a tactic between: + + - induction + - simp (optionally specify which lemmas to provide) + - rw [] + - grind + + Take a deep breath, think carefully, and predict the next tactic to perform. You can do it. + """ + raise NotHandled + + +defs = """ +def sum_upto (i: Nat) (j: Nat) (f: Nat → Nat) : Nat := match j with +| 0 => f i +| j + 1 => sum_upto i j f + f (j + 1) + +notation "∑_{" i " ← " a "}^{" b "} " f => sum_upto a b (fun i => f) +""" +thm_stmt = """theorem sum_upto_mul_two' (n : Nat) : (∑_{i ← 0}^{n} i) = n * (n + 1) / 2 := by\n""" + +server = pg.Server(imports=["Init"]) +server.load_definitions(defs) + +[ + thm, +] = server.load_sorry(thm_stmt + " sorry") + +searcher = ProofSearcher(server, thm.goal_state, predict_tactic, defs + thm_stmt) + +with handler(OpenAIAPIProvider(OpenAI())): + + def step(): + searcher.step() + + progress = True + + while not step(): + pass + +searcher.render(searcher.done[0])