Skip to content

Commit f1fb05e

Browse files
committed
planner + executor agent
1 parent 05c43de commit f1fb05e

File tree

5 files changed

+1768
-0
lines changed

5 files changed

+1768
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Planner + Executor Agent Examples
2+
3+
This directory contains examples for the `PlannerExecutorAgent`, a two-tier agent
4+
architecture with separate Planner (7B+) and Executor (3B-7B) models.
5+
6+
## Examples
7+
8+
| File | Description |
9+
|------|-------------|
10+
| `minimal_example.py` | Basic usage with OpenAI models |
11+
| `local_models_example.py` | Using local HuggingFace/MLX models |
12+
| `custom_config_example.py` | Custom configuration (escalation, retry, vision) |
13+
| `tracing_example.py` | Full tracing integration for Predicate Studio |
14+
15+
## Architecture
16+
17+
```
18+
┌─────────────────────────────────────────────────────────────┐
19+
│ PlannerExecutorAgent │
20+
├─────────────────────────────────────────────────────────────┤
21+
│ Planner (7B+) │ Executor (3B-7B) │
22+
│ ───────────── │ ──────────────── │
23+
│ • Generates JSON plan │ • Executes each step │
24+
│ • Includes predicates │ • Snapshot-first approach │
25+
│ • Handles replanning │ • Vision fallback │
26+
└─────────────────────────────────────────────────────────────┘
27+
28+
29+
┌─────────────────────────────────────────────────────────────┐
30+
│ AgentRuntime │
31+
│ • Snapshots with limit escalation │
32+
│ • Predicate verification │
33+
│ • Tracing for Studio visualization │
34+
└─────────────────────────────────────────────────────────────┘
35+
```
36+
37+
## Quick Start
38+
39+
```python
40+
from predicate.agents import PlannerExecutorAgent, PlannerExecutorConfig
41+
from predicate.llm_provider import OpenAIProvider
42+
from predicate import AsyncPredicateBrowser
43+
from predicate.agent_runtime import AgentRuntime
44+
45+
# Create LLM providers
46+
planner = OpenAIProvider(model="gpt-4o")
47+
executor = OpenAIProvider(model="gpt-4o-mini")
48+
49+
# Create agent
50+
agent = PlannerExecutorAgent(
51+
planner=planner,
52+
executor=executor,
53+
)
54+
55+
# Run task
56+
async with AsyncPredicateBrowser() as browser:
57+
page = await browser.new_page()
58+
await page.goto("https://example.com")
59+
60+
runtime = AgentRuntime.from_page(page)
61+
result = await agent.run(
62+
runtime=runtime,
63+
task="Find the main heading on this page",
64+
)
65+
print(f"Success: {result.success}")
66+
```
67+
68+
## Configuration
69+
70+
### Snapshot Escalation
71+
72+
Control how the agent increases snapshot limits when elements are missing:
73+
74+
```python
75+
from predicate.agents import SnapshotEscalationConfig
76+
77+
# Default: 60 -> 90 -> 120 -> 150 -> 180 -> 200
78+
config = PlannerExecutorConfig()
79+
80+
# Disable escalation (always use 60)
81+
config = PlannerExecutorConfig(
82+
snapshot=SnapshotEscalationConfig(enabled=False)
83+
)
84+
85+
# Custom step size: 60 -> 110 -> 160 -> 200
86+
config = PlannerExecutorConfig(
87+
snapshot=SnapshotEscalationConfig(limit_step=50)
88+
)
89+
```
90+
91+
### Retry Configuration
92+
93+
```python
94+
from predicate.agents import RetryConfig
95+
96+
config = PlannerExecutorConfig(
97+
retry=RetryConfig(
98+
verify_timeout_s=15.0, # Verification timeout
99+
verify_max_attempts=8, # Max verification attempts
100+
max_replans=2, # Max replanning attempts
101+
)
102+
)
103+
```
104+
105+
### Vision Fallback
106+
107+
```python
108+
from predicate.agents.browser_agent import VisionFallbackConfig
109+
110+
config = PlannerExecutorConfig(
111+
vision=VisionFallbackConfig(
112+
enabled=True,
113+
max_vision_calls=5,
114+
)
115+
)
116+
```
117+
118+
## Tracing for Predicate Studio
119+
120+
To visualize agent runs in Predicate Studio:
121+
122+
```python
123+
from predicate.tracer_factory import create_tracer
124+
125+
tracer = create_tracer(
126+
api_key="sk_...",
127+
upload_trace=True,
128+
goal="Search and add to cart",
129+
agent_type="PlannerExecutorAgent",
130+
)
131+
132+
agent = PlannerExecutorAgent(
133+
planner=planner,
134+
executor=executor,
135+
tracer=tracer, # Pass tracer for visualization
136+
)
137+
138+
# ... run agent ...
139+
140+
tracer.close() # Upload trace to Studio
141+
```
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
#!/usr/bin/env python3
2+
"""
3+
PlannerExecutorAgent example with local HuggingFace models.
4+
5+
This example demonstrates using local models instead of cloud APIs:
6+
- Planner: DeepSeek-R1-Distill-Qwen-14B (reasoning model)
7+
- Executor: Qwen2.5-7B-Instruct (fast instruction following)
8+
9+
Usage:
10+
export PREDICATE_API_KEY="sk_..." # Optional, for cloud browser
11+
python local_models_example.py
12+
13+
Requirements:
14+
pip install torch transformers accelerate
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import asyncio
20+
import os
21+
from dataclasses import dataclass
22+
23+
import torch
24+
from transformers import AutoModelForCausalLM, AutoTokenizer
25+
26+
from predicate import AsyncPredicateBrowser
27+
from predicate.agent_runtime import AgentRuntime
28+
from predicate.agents import (
29+
PlannerExecutorAgent,
30+
PlannerExecutorConfig,
31+
SnapshotEscalationConfig,
32+
)
33+
from predicate.backends.playwright_backend import PlaywrightBackend
34+
from predicate.llm_provider import LLMProvider, LLMResponse
35+
36+
37+
class LocalHFProvider(LLMProvider):
    """
    Local HuggingFace model provider.

    Loads a causal-LM checkpoint from the HuggingFace hub once at
    construction time and runs inference locally in ``generate``.

    NOTE(review): the original decorated this class with ``@dataclass``
    while also defining a custom ``__init__`` and declaring no fields.
    ``@dataclass`` keeps an explicitly defined ``__init__``, but it still
    injects a field-less ``__eq__`` (making *every* instance compare equal)
    and a state-less ``__repr__``. The decorator is removed here; identity
    comparison and the default repr are the correct behavior.
    """

    def __init__(
        self,
        model_name: str,
        device_map: str = "auto",
        torch_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Load tokenizer and model weights.

        Args:
            model_name: HuggingFace model id, e.g. "Qwen/Qwen2.5-7B-Instruct".
            device_map: Forwarded to ``from_pretrained``; "auto" lets
                accelerate shard the model across available devices.
            torch_dtype: Weight dtype; bfloat16 halves memory vs float32.
        """
        super().__init__(model=model_name)
        self._model_name = model_name

        print(f"Loading model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=device_map,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        print(f"Model loaded: {model_name}")

    def generate(
        self,
        system_prompt: str,
        user_prompt: str,
        **kwargs,
    ) -> LLMResponse:
        """
        Run one chat-formatted generation and return an ``LLMResponse``.

        Args:
            system_prompt: System message prepended to the conversation.
            user_prompt: User message to answer.
            **kwargs: Optional ``max_new_tokens`` (default 512) and
                ``temperature`` (default 0.0, i.e. greedy decoding).

        Returns:
            LLMResponse with the decoded completion and token accounting.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        # Render the conversation with the model's own chat template so
        # the prompt matches how the checkpoint was fine-tuned.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        prompt_tokens = inputs.input_ids.shape[1]

        max_new_tokens = kwargs.get("max_new_tokens", 512)
        temperature = kwargs.get("temperature", 0.0)

        # temperature == 0 means greedy: pass do_sample=False and drop the
        # temperature argument (None) to avoid transformers sampling warnings.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else None,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        completion_tokens = outputs.shape[1] - prompt_tokens
        # Decode only the newly generated tail, not the echoed prompt.
        response_text = self.tokenizer.decode(
            outputs[0][prompt_tokens:],
            skip_special_tokens=True,
        )

        return LLMResponse(
            content=response_text,
            model_name=self._model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

    def supports_json_mode(self) -> bool:
        """Local HF generation has no structured/JSON output mode."""
        return False

    @property
    def model_name(self) -> str:
        """The HuggingFace model id this provider was constructed with."""
        return self._model_name
120+
121+
122+
async def main() -> None:
    """Run the PlannerExecutorAgent example with locally loaded HF models."""
    api_key = os.getenv("PREDICATE_API_KEY")

    # Model ids are overridable via environment variables so the demo can
    # be scaled down to whatever fits on the local hardware.
    planner_id = os.getenv(
        "PLANNER_MODEL",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
    )
    executor_id = os.getenv(
        "EXECUTOR_MODEL",
        "Qwen/Qwen2.5-7B-Instruct",
    )

    planner_provider = LocalHFProvider(planner_id)
    executor_provider = LocalHFProvider(executor_id)

    # Local models get slightly larger snapshot limits and a bigger planner
    # token budget than the cloud defaults.
    escalation = SnapshotEscalationConfig(
        limit_base=80,
        limit_step=40,
        limit_max=200,
    )
    agent = PlannerExecutorAgent(
        planner=planner_provider,
        executor=executor_provider,
        config=PlannerExecutorConfig(
            snapshot=escalation,
            planner_max_tokens=2048,
            executor_max_tokens=128,
        ),
    )

    task = "Navigate to example.com and find the main heading"
    rule = "=" * 60

    print(f"Task: {task}")
    print(f"Planner: {planner_id}")
    print(f"Executor: {executor_id}")
    print(rule)

    async with AsyncPredicateBrowser(
        api_key=api_key,
        headless=False,
    ) as browser:
        page = await browser.new_page()
        await page.goto("https://example.com")
        await page.wait_for_load_state("networkidle")

        runtime = AgentRuntime(backend=PlaywrightBackend(page))

        result = await agent.run(
            runtime=runtime,
            task=task,
            start_url="https://example.com",
        )

        print("\n" + rule)
        print(f"Success: {result.success}")
        print(f"Steps: {result.steps_completed}/{result.steps_total}")
        print(f"Duration: {result.total_duration_ms}ms")
187+
188+
189+
# Script entry point: drive the async example under asyncio's event loop.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)