diff --git a/openadapt_evals/benchmarks/cli.py b/openadapt_evals/benchmarks/cli.py
index 3f1d604..a1eeccd 100644
--- a/openadapt_evals/benchmarks/cli.py
+++ b/openadapt_evals/benchmarks/cli.py
@@ -460,6 +460,23 @@ def cmd_run(args: argparse.Namespace) -> int:
     if use_controller:
         print(f"Using DemoController (max_retries={args.max_retries}, max_replans={args.max_replans})")
 
+    # Set up correction store if requested
+    correction_store = None
+    enable_correction_capture = getattr(args, "enable_correction_capture", False)
+    correction_library_path = getattr(args, "correction_library", None)
+    if correction_library_path:
+        from openadapt_evals.correction_store import CorrectionStore
+
+        correction_store = CorrectionStore(correction_library_path)
+        print(f"Correction library: {correction_library_path}")
+        if enable_correction_capture:
+            print("Correction capture: ENABLED (will prompt for human corrections on failure)")
+    elif enable_correction_capture:
+        # --enable-correction-capture is documented as requiring
+        # --correction-library; warn instead of silently ignoring the flag.
+        print("WARNING: --enable-correction-capture ignored (no --correction-library given)")
+        enable_correction_capture = False
+
     # Run evaluation
     if use_controller:
         from openadapt_evals.demo_controller import run_with_controller
@@ -475,6 +492,8 @@ def cmd_run(args: argparse.Namespace) -> int:
                 max_steps=args.max_steps,
                 max_retries=args.max_retries,
                 max_replans=args.max_replans,
+                correction_store=correction_store,
+                enable_correction_capture=enable_correction_capture,
             )
             results.append(result)
         else:
@@ -2432,6 +2451,10 @@ def main() -> int:
     run_parser.add_argument("--focus-check-method", type=str, default="win32",
                             choices=["win32", "a11y", "both"],
                             help="Method for foreground window check: win32 (fast, default), a11y, or both")
+    run_parser.add_argument("--correction-library", type=str, default=None,
+                            help="Path to correction library directory for the correction flywheel")
+    run_parser.add_argument("--enable-correction-capture", action="store_true",
+                            help="Enable HITL correction capture when agent fails (requires --correction-library)")
 
     # Live evaluation (full control)
     live_parser = subparsers.add_parser("live", help="Run live evaluation against WAA server (full control)")
diff
--git a/openadapt_evals/correction_capture.py b/openadapt_evals/correction_capture.py new file mode 100644 index 0000000..e2b5fc3 --- /dev/null +++ b/openadapt_evals/correction_capture.py @@ -0,0 +1,238 @@ +"""Correction capture for the correction flywheel. + +Captures a human correction using openadapt-capture's Recorder (primary path) +or falls back to simple periodic screenshots via PIL if openadapt-capture is +not available. + +The Recorder provides full input event recording (mouse + keyboard) plus +action-gated screenshots, which gives the VLM parser much richer context +for understanding what the human did. +""" + +from __future__ import annotations + +import logging +import os +import time +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class CorrectionResult: + """Result of a correction capture session.""" + + screenshots: list[str] = field(default_factory=list) # paths + capture_dir: str | None = None # openadapt-capture directory (if used) + duration_seconds: float = 0.0 + output_dir: str = "" + + +def _take_screenshot(output_path: str) -> str | None: + """Take a screenshot and save to output_path. 
Returns path or None."""
+    try:
+        from PIL import ImageGrab
+
+        img = ImageGrab.grab()
+        img.save(output_path)
+        return output_path
+    except Exception as exc:
+        logger.warning("Screenshot failed: %s", exc)
+        return None
+
+
+def _has_recorder() -> bool:
+    """Check if openadapt-capture Recorder is available."""
+    try:
+        from openadapt_capture.recorder import Recorder  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
+def _prompt_user(step_desc: str, explanation: str) -> None:
+    """Print the correction prompt to the terminal."""
+    print("\n" + "=" * 60)
+    print("CORRECTION NEEDED")
+    print("=" * 60)
+    print(f"Failed step: {step_desc}")
+    if explanation:
+        print(f"Reason: {explanation}")
+    print("\nPlease complete this step manually.")
+    print("Press Enter when done...")
+    print("=" * 60 + "\n")
+
+
+def _wait_for_enter(timeout_seconds: int) -> None:
+    """Block until user presses Enter or timeout expires.
+
+    On POSIX, poll stdin with select() so the timeout is honored.  On
+    Windows, select() supports only sockets and raises OSError for stdin
+    (hasattr(select, "select") is still True there), so fall back to a
+    plain blocking input() in that case, sacrificing the timeout.
+    """
+    import select
+    import sys
+
+    try:
+        if hasattr(select, "select") and os.name != "nt":
+            remaining = float(timeout_seconds)
+            while remaining > 0:
+                ready, _, _ = select.select([sys.stdin], [], [], 1.0)
+                if ready:
+                    sys.stdin.readline()
+                    break
+                remaining -= 1.0
+        else:
+            input()
+    except (OSError, ValueError):
+        # select() rejected stdin (non-socket / redirected handle).
+        try:
+            input()
+        except EOFError:
+            logger.info("stdin closed, stopping capture after timeout")
+            time.sleep(min(timeout_seconds, 10))
+    except EOFError:
+        logger.info("stdin closed, stopping capture after timeout")
+        time.sleep(min(timeout_seconds, 10))
+
+
+class CorrectionCapture:
+    """Capture a human correction for a failed step."""
+
+    def __init__(self, output_dir: str):
+        self.output_dir = output_dir
+        os.makedirs(output_dir, exist_ok=True)
+
+    def capture_correction(
+        self,
+        failure_context: dict,
+        timeout_seconds: int = 120,
+        interval_seconds: float = 2.0,
+    ) -> CorrectionResult:
+        """Capture a human correction.
+
+        Uses openadapt-capture Recorder if available (full input events +
+        action-gated screenshots), otherwise falls back to periodic PIL
+        screenshots.
+ """ + # Save the failure screenshot as "before" + before_path = os.path.join(self.output_dir, "before.png") + before_screenshots = [] + if failure_context.get("screenshot_bytes"): + with open(before_path, "wb") as f: + f.write(failure_context["screenshot_bytes"]) + before_screenshots.append(before_path) + elif failure_context.get("screenshot_path"): + before_screenshots.append(failure_context["screenshot_path"]) + + step_desc = failure_context.get("step_action", "this step") + explanation = failure_context.get("explanation", "") + + _prompt_user(step_desc, explanation) + + if _has_recorder(): + return self._capture_with_recorder( + before_screenshots, timeout_seconds + ) + else: + logger.info("openadapt-capture not available, using simple screenshot capture") + return self._capture_simple( + before_screenshots, timeout_seconds, interval_seconds + ) + + def _capture_with_recorder( + self, + before_screenshots: list[str], + timeout_seconds: int, + ) -> CorrectionResult: + """Full capture using openadapt-capture Recorder.""" + from openadapt_capture.recorder import Recorder + + capture_dir = os.path.join(self.output_dir, "recording") + start = time.monotonic() + + with Recorder( + capture_dir, + task_description="Human correction for failed agent step", + capture_video=False, # screenshots only, faster + capture_audio=False, + ) as recorder: + recorder.wait_for_ready(timeout=30) + _wait_for_enter(timeout_seconds) + recorder.stop() + + duration = time.monotonic() - start + + # Extract screenshots from the capture + screenshot_paths = list(before_screenshots) + try: + from openadapt_capture.capture import CaptureSession + + session = CaptureSession.load(capture_dir) + for i, action in enumerate(session.actions()): + if action.screenshot is not None: + path = os.path.join(self.output_dir, f"action_{i:04d}.png") + action.screenshot.save(path) + screenshot_paths.append(path) + except Exception as exc: + logger.warning("Failed to extract screenshots from capture: %s", 
exc) + # Fall back to taking a final screenshot + after_path = os.path.join(self.output_dir, "after.png") + taken = _take_screenshot(after_path) + if taken: + screenshot_paths.append(taken) + + logger.info( + "Recorder capture complete: %d screenshots in %.1fs", + len(screenshot_paths), + duration, + ) + return CorrectionResult( + screenshots=screenshot_paths, + capture_dir=capture_dir, + duration_seconds=duration, + output_dir=self.output_dir, + ) + + def _capture_simple( + self, + before_screenshots: list[str], + timeout_seconds: int, + interval_seconds: float, + ) -> CorrectionResult: + """Fallback: periodic PIL screenshots.""" + import threading + + start = time.monotonic() + stop_event = threading.Event() + screenshot_paths: list[str] = [] + + def _capture_loop(): + idx = 0 + while not stop_event.is_set(): + stop_event.wait(interval_seconds) + if stop_event.is_set(): + break + path = os.path.join(self.output_dir, f"capture_{idx:04d}.png") + taken = _take_screenshot(path) + if taken: + screenshot_paths.append(taken) + idx += 1 + + capture_thread = threading.Thread(target=_capture_loop, daemon=True) + capture_thread.start() + + _wait_for_enter(timeout_seconds) + + stop_event.set() + capture_thread.join(timeout=5) + + # Final "after" screenshot + after_path = os.path.join(self.output_dir, "after.png") + taken = _take_screenshot(after_path) + if taken: + screenshot_paths.append(taken) + + all_screenshots = list(before_screenshots) + screenshot_paths + duration = time.monotonic() - start + + logger.info( + "Simple capture complete: %d screenshots in %.1fs", + len(all_screenshots), + duration, + ) + return CorrectionResult( + screenshots=all_screenshots, + duration_seconds=duration, + output_dir=self.output_dir, + ) diff --git a/openadapt_evals/correction_parser.py b/openadapt_evals/correction_parser.py new file mode 100644 index 0000000..f4b2220 --- /dev/null +++ b/openadapt_evals/correction_parser.py @@ -0,0 +1,86 @@ +"""Parse a human correction capture into a 
PlanStep. + +Uses a VLM call to compare before/after screenshots and describe what +the human did in the same format as a plan step (think/action/expect). +""" + +from __future__ import annotations + +import json +import logging +import os + +from openadapt_evals.vlm import vlm_call + +logger = logging.getLogger(__name__) + +_PARSE_PROMPT = """\ +The agent was trying to perform a step but failed. A human then completed the step manually. + +Failed step description: {step_action} +Failure explanation: {failure_explanation} + +Compare the BEFORE screenshot (when the agent failed) and the AFTER screenshot \ +(after the human completed the step). Describe what the human did to complete the step. + +Respond in this exact JSON format: +{{ + "think": "reasoning about what needed to happen and why the agent failed", + "action": "concrete description of what the human did (e.g., 'Click the Display button in the left sidebar')", + "expect": "what the screen looks like after the action" +}} + +Respond with ONLY the JSON object, no other text.""" + + +def parse_correction( + step_action: str, + failure_explanation: str, + before_screenshot: bytes, + after_screenshot: bytes, + model: str = "gpt-4.1-mini", + provider: str = "openai", +) -> dict: + """Parse before/after screenshots into a PlanStep dict. + + Returns dict with keys: think, action, expect. 
+ """ + prompt = _PARSE_PROMPT.format( + step_action=step_action, + failure_explanation=failure_explanation, + ) + + response = vlm_call( + prompt, + images=[before_screenshot, after_screenshot], + model=model, + provider=provider, + max_tokens=512, + ) + + # Extract JSON from response + try: + # Try direct parse first + result = json.loads(response) + except json.JSONDecodeError: + # Try to find JSON in the response + import re + + match = re.search(r"\{[^}]+\}", response, re.DOTALL) + if match: + result = json.loads(match.group()) + else: + logger.error("Failed to parse VLM response as JSON: %s", response[:200]) + result = { + "think": f"Human corrected the step: {step_action}", + "action": step_action, + "expect": "Step completed successfully", + } + + # Ensure required keys exist + for key in ("think", "action", "expect"): + if key not in result: + result[key] = "" + + logger.info("Parsed correction: action=%s", result["action"][:80]) + return result diff --git a/openadapt_evals/correction_store.py b/openadapt_evals/correction_store.py new file mode 100644 index 0000000..520f675 --- /dev/null +++ b/openadapt_evals/correction_store.py @@ -0,0 +1,97 @@ +"""JSON-file-based correction library for the correction flywheel. + +Stores corrections as individual JSON files in a directory. Retrieval uses +exact task_id match + fuzzy string similarity on step descriptions. 
+""" + +from __future__ import annotations + +import difflib +import json +import logging +import os +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) + + +@dataclass +class CorrectionEntry: + """A single stored correction.""" + + task_id: str + step_description: str # original step action text + failure_screenshot_path: str + failure_explanation: str + correction_step: dict # PlanStep as dict (think/action/expect) + timestamp: str = "" # ISO format + run_id: str = "" + entry_id: str = "" + + def __post_init__(self): + if not self.timestamp: + self.timestamp = datetime.now(timezone.utc).isoformat() + if not self.entry_id: + self.entry_id = uuid.uuid4().hex[:12] + + +class CorrectionStore: + """Manages a directory of correction JSON files.""" + + def __init__(self, library_dir: str = "correction_library"): + self.library_dir = library_dir + os.makedirs(library_dir, exist_ok=True) + + def save(self, entry: CorrectionEntry) -> str: + """Save correction, return entry ID.""" + path = os.path.join(self.library_dir, f"{entry.entry_id}.json") + with open(path, "w") as f: + json.dump(asdict(entry), f, indent=2) + logger.info("Saved correction %s for task %s", entry.entry_id, entry.task_id) + return entry.entry_id + + def find( + self, + task_id: str, + step_description: str, + top_k: int = 3, + threshold: float = 0.6, + ) -> list[CorrectionEntry]: + """Find matching corrections by task_id + fuzzy step description match.""" + all_entries = self.load_all() + + # Filter to matching task_id + candidates = [e for e in all_entries if e.task_id == task_id] + if not candidates: + return [] + + # Score by string similarity on step_description + scored = [] + for entry in candidates: + ratio = difflib.SequenceMatcher( + None, step_description.lower(), entry.step_description.lower() + ).ratio() + if ratio >= threshold: + scored.append((ratio, entry)) + + scored.sort(key=lambda x: x[0], 
reverse=True) + return [entry for _, entry in scored[:top_k]] + + def load_all(self) -> list[CorrectionEntry]: + """Load all corrections from the library directory.""" + entries = [] + if not os.path.isdir(self.library_dir): + return entries + for fname in os.listdir(self.library_dir): + if not fname.endswith(".json"): + continue + path = os.path.join(self.library_dir, fname) + try: + with open(path) as f: + data = json.load(f) + entries.append(CorrectionEntry(**data)) + except (json.JSONDecodeError, TypeError, KeyError) as exc: + logger.warning("Skipping invalid correction file %s: %s", fname, exc) + return entries diff --git a/openadapt_evals/demo_controller.py b/openadapt_evals/demo_controller.py index 345ba8a..124739e 100644 --- a/openadapt_evals/demo_controller.py +++ b/openadapt_evals/demo_controller.py @@ -51,6 +51,7 @@ BenchmarkTask, ) from openadapt_evals.agents.base import BenchmarkAgent +from openadapt_evals.correction_store import CorrectionEntry, CorrectionStore from openadapt_evals.agents.claude_computer_use_agent import _parse_multilevel_demo from openadapt_evals.plan_verify import ( VerificationResult, @@ -143,6 +144,8 @@ def __init__( max_replans: int = 2, verify_model: str = "gpt-4.1-mini", verify_provider: str = "openai", + correction_store: CorrectionStore | None = None, + enable_correction_capture: bool = False, ) -> None: self.agent = agent self.adapter = adapter @@ -151,6 +154,8 @@ def __init__( self.max_replans = max_replans self.verify_model = verify_model self.verify_provider = verify_provider + self.correction_store = correction_store + self.enable_correction_capture = enable_correction_capture # Parse the demo into a structured plan self.plan_state = self._parse_demo(demo_text) @@ -386,6 +391,16 @@ def execute( current.status = "done" self._advance() elif current.attempts >= self.max_retries: + # Check correction library for a stored fix + if self._try_stored_correction(task, current): + continue # re-execute with injected correction + + 
+                    # Capture human correction if enabled
+                    if self._try_capture_correction(
+                        task, current, obs, screenshot_bytes
+                    ):
+                        continue  # correction captured, step marked done
+
                     logger.warning(
                         "Step %d failed after %d attempts (last: %s); %s",
                         current.step_num,
@@ -819,6 +834,144 @@ def _build_step_verification_summary(self) -> str:
 
         return "\n".join(lines)
 
+    # ------------------------------------------------------------------
+    # Correction flywheel
+    # ------------------------------------------------------------------
+
+    def _try_stored_correction(
+        self, task: BenchmarkTask, current: PlanStep
+    ) -> bool:
+        """Check correction library for a stored fix. Returns True if injected."""
+        if not self.correction_store:
+            return False
+
+        corrections = self.correction_store.find(
+            task_id=task.task_id,
+            step_description=current.action,
+        )
+        if not corrections:
+            return False
+
+        correction = corrections[0]
+
+        # Guard against an infinite retry loop: injection resets attempts to
+        # 0, so if the injected correction itself fails at max_retries the
+        # same stored entry would be found and re-injected forever.  Track
+        # what was already applied per step and refuse to re-apply it.
+        applied = getattr(self, "_applied_correction_ids", None)
+        if applied is None:
+            applied = self._applied_correction_ids = set()
+        key = (current.step_num, correction.entry_id)
+        if key in applied:
+            return False
+        applied.add(key)
+
+        logger.info(
+            "Found stored correction for step %d (match: %s)",
+            current.step_num,
+            correction.entry_id,
+        )
+
+        # Inject correction as replacement step
+        corrected = PlanStep(
+            step_num=current.step_num,
+            think=correction.correction_step.get("think", current.think),
+            action=correction.correction_step.get("action", current.action),
+            expect=correction.correction_step.get("expect", current.expect),
+            status="in_progress",
+            attempts=0,
+        )
+        self.plan_state.steps[self.plan_state.current_step_idx] = corrected
+        return True
+
+    def _try_capture_correction(
+        self,
+        task: BenchmarkTask,
+        current: PlanStep,
+        obs: BenchmarkObservation,
+        screenshot_bytes: bytes | None,
+    ) -> bool:
+        """Capture human correction if enabled.
Returns True if captured.""" + if not self.enable_correction_capture or not self.correction_store: + return False + + entry = self._capture_human_correction(task, current, obs, screenshot_bytes) + if entry is None: + return False + + self.correction_store.save(entry) + + # Mark step as done and advance + current.status = "done" + self._advance() + return True + + def _capture_human_correction( + self, + task: BenchmarkTask, + failed_step: PlanStep, + observation: BenchmarkObservation, + screenshot_bytes: bytes | None, + ) -> CorrectionEntry | None: + """Activate correction capture, wait for human, parse result.""" + import os + import tempfile + import uuid + + from openadapt_evals.correction_capture import CorrectionCapture + from openadapt_evals.correction_parser import parse_correction + + run_id = uuid.uuid4().hex[:8] + capture_dir = os.path.join( + tempfile.gettempdir(), + f"correction_{task.task_id}_{failed_step.step_num}_{run_id}", + ) + + failure_explanation = "" + if failed_step.verification_result: + failure_explanation = failed_step.verification_result.explanation + + capture = CorrectionCapture(output_dir=capture_dir) + result = capture.capture_correction( + failure_context={ + "screenshot_bytes": screenshot_bytes, + "step_action": failed_step.action, + "explanation": failure_explanation, + }, + ) + + if len(result.screenshots) < 2: + logger.warning("Correction capture got fewer than 2 screenshots; skipping") + return None + + # Read before and after screenshots + try: + with open(result.screenshots[0], "rb") as f: + before_bytes = f.read() + with open(result.screenshots[-1], "rb") as f: + after_bytes = f.read() + except (FileNotFoundError, OSError) as exc: + logger.error("Failed to read correction screenshots: %s", exc) + return None + + # Parse correction via VLM + correction_step = parse_correction( + step_action=failed_step.action, + failure_explanation=failure_explanation, + before_screenshot=before_bytes, + after_screenshot=after_bytes, + 
model=self.verify_model, + provider=self.verify_provider, + ) + + return CorrectionEntry( + task_id=task.task_id, + step_description=failed_step.action, + failure_screenshot_path=result.screenshots[0], + failure_explanation=failure_explanation, + correction_step=correction_step, + run_id=run_id, + ) + # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ @@ -868,6 +1009,8 @@ def run_with_controller( max_replans: int = 2, verify_model: str = "gpt-4.1-mini", verify_provider: str = "openai", + correction_store: CorrectionStore | None = None, + enable_correction_capture: bool = False, ) -> BenchmarkResult: """Run a task using the demo-conditioned controller. @@ -884,6 +1027,8 @@ def run_with_controller( max_replans: Maximum replans of the remaining plan. verify_model: VLM model for verification. verify_provider: VLM provider for verification. + correction_store: Optional CorrectionStore for retrieval/storage. + enable_correction_capture: Whether to capture human corrections. Returns: A BenchmarkResult with the execution outcome. 
@@ -896,5 +1041,7 @@ def run_with_controller( max_replans=max_replans, verify_model=verify_model, verify_provider=verify_provider, + correction_store=correction_store, + enable_correction_capture=enable_correction_capture, ) return controller.execute(task, max_steps=max_steps) diff --git a/tests/test_correction_flywheel.py b/tests/test_correction_flywheel.py new file mode 100644 index 0000000..e2cd21b --- /dev/null +++ b/tests/test_correction_flywheel.py @@ -0,0 +1,518 @@ +"""Tests for the correction flywheel: store, capture, parser, and controller integration.""" + +from __future__ import annotations + +import json +import os +import tempfile +from dataclasses import asdict +from unittest.mock import MagicMock, patch + +import pytest + +from openadapt_evals.adapters.base import ( + BenchmarkAction, + BenchmarkObservation, + BenchmarkResult, + BenchmarkTask, +) +from openadapt_evals.correction_store import CorrectionEntry, CorrectionStore +from openadapt_evals.demo_controller import DemoController, PlanState, PlanStep +from openadapt_evals.plan_verify import VerificationResult + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +SAMPLE_DEMO = """\ +GOAL: Change the display resolution + +PLAN: +1. Open Display settings +2. Change resolution + +REFERENCE TRAJECTORY (for disambiguation -- adapt actions to your actual screen): + +Step 1: + Think: I need to open Display settings. + Action: Click the Display button in the left sidebar. + Expect: Display settings pane should open. + +Step 2: + Think: I need to change resolution. + Action: Click the Resolution dropdown and select 1920x1080. + Expect: Resolution should change to 1920x1080. 
+""" + + +def _make_entry( + task_id: str = "display-resolution", + step_desc: str = "Click the Display button in the left sidebar.", + **kwargs, +) -> CorrectionEntry: + defaults = { + "task_id": task_id, + "step_description": step_desc, + "failure_screenshot_path": "/tmp/before.png", + "failure_explanation": "Agent clicked wrong pane", + "correction_step": { + "think": "The Display button is in the left sidebar, third item", + "action": "Click the Display button (third item in left sidebar)", + "expect": "Display settings pane opens", + }, + "run_id": "test-run-01", + } + defaults.update(kwargs) + return CorrectionEntry(**defaults) + + +# --------------------------------------------------------------------------- +# CorrectionStore tests +# --------------------------------------------------------------------------- + + +class TestCorrectionStore: + def test_save_and_find(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + entry = _make_entry() + entry_id = store.save(entry) + + results = store.find( + task_id="display-resolution", + step_description="Click the Display button in the left sidebar.", + ) + assert len(results) == 1 + assert results[0].entry_id == entry_id + assert results[0].correction_step["action"] == entry.correction_step["action"] + + def test_fuzzy_match(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + store.save(_make_entry()) + + # Slightly different description should still match + results = store.find( + task_id="display-resolution", + step_description="Click Display button on the left sidebar", + ) + assert len(results) == 1 + + def test_no_match_different_task(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + store.save(_make_entry(task_id="display-resolution")) + + results = store.find( + task_id="change-wallpaper", + step_description="Click the Display button in the left sidebar.", + ) + assert len(results) == 0 + + def test_no_match_low_similarity(self, 
tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + store.save(_make_entry()) + + results = store.find( + task_id="display-resolution", + step_description="Open the terminal and type ls", + ) + assert len(results) == 0 + + def test_load_all(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + store.save(_make_entry(entry_id="aaa")) + store.save(_make_entry(entry_id="bbb")) + + all_entries = store.load_all() + assert len(all_entries) == 2 + ids = {e.entry_id for e in all_entries} + assert ids == {"aaa", "bbb"} + + def test_entry_serialization_roundtrip(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + entry = _make_entry() + store.save(entry) + + loaded = store.load_all() + assert len(loaded) == 1 + assert asdict(loaded[0]) == asdict(entry) + + def test_empty_store(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + assert store.load_all() == [] + assert store.find("any", "any") == [] + + def test_multiple_matches_ranked(self, tmp_path): + store = CorrectionStore(str(tmp_path / "corrections")) + + # High similarity + store.save( + _make_entry( + step_desc="Click the Display button in the left sidebar.", + entry_id="exact", + ) + ) + # Medium similarity + store.save( + _make_entry( + step_desc="Click Display in sidebar", + entry_id="partial", + ) + ) + + results = store.find( + task_id="display-resolution", + step_description="Click the Display button in the left sidebar.", + top_k=2, + ) + assert len(results) == 2 + # Exact match should come first + assert results[0].entry_id == "exact" + + def test_skips_invalid_json(self, tmp_path): + lib_dir = tmp_path / "corrections" + lib_dir.mkdir() + (lib_dir / "bad.json").write_text("not json") + (lib_dir / "readme.txt").write_text("ignore me") + + store = CorrectionStore(str(lib_dir)) + assert store.load_all() == [] + + +# --------------------------------------------------------------------------- +# CorrectionParser tests +# 
--------------------------------------------------------------------------- + + +class TestCorrectionParser: + @patch("openadapt_evals.correction_parser.vlm_call") + def test_parse_correction_returns_plan_step(self, mock_vlm): + mock_vlm.return_value = json.dumps( + { + "think": "The Display button is third in the sidebar", + "action": "Click the third item in the left sidebar labeled Display", + "expect": "Display settings pane opens showing resolution options", + } + ) + + from openadapt_evals.correction_parser import parse_correction + + result = parse_correction( + step_action="Click the Display button", + failure_explanation="Clicked wrong button", + before_screenshot=b"fake-png-before", + after_screenshot=b"fake-png-after", + ) + + assert result["think"] == "The Display button is third in the sidebar" + assert "Display" in result["action"] + assert "expect" in result + mock_vlm.assert_called_once() + + @patch("openadapt_evals.correction_parser.vlm_call") + def test_parse_correction_handles_bad_json(self, mock_vlm): + mock_vlm.return_value = "Sorry, I can't parse that." 
+ + from openadapt_evals.correction_parser import parse_correction + + result = parse_correction( + step_action="Click Display", + failure_explanation="Failed", + before_screenshot=b"fake", + after_screenshot=b"fake", + ) + + # Should fall back to reasonable defaults + assert "action" in result + assert "think" in result + assert "expect" in result + + +# --------------------------------------------------------------------------- +# DemoController correction integration tests +# --------------------------------------------------------------------------- + + +def _make_mock_agent(): + agent = MagicMock() + agent._external_step_control = False + agent.act.return_value = BenchmarkAction( + type="click", x=100, y=200, raw_action={} + ) + return agent + + +def _make_mock_adapter(): + adapter = MagicMock() + adapter.observe.return_value = BenchmarkObservation( + screenshot=b"fake-screenshot-bytes", + raw_observation={}, + ) + adapter.step.return_value = ( + BenchmarkObservation(screenshot=b"fake-screenshot-bytes", raw_observation={}), + False, # env_done + {}, # info + ) + adapter.evaluate.return_value = BenchmarkResult( + task_id="test", success=True, score=1.0 + ) + return adapter + + +class TestDemoControllerCorrections: + @patch("openadapt_evals.demo_controller.verify_step") + def test_uses_stored_correction(self, mock_verify, tmp_path): + """Controller retrieves and uses stored correction instead of replanning.""" + store = CorrectionStore(str(tmp_path / "corrections")) + store.save( + _make_entry( + task_id="test-task", + step_desc="Click the Display button in the left sidebar.", + ) + ) + + agent = _make_mock_agent() + adapter = _make_mock_adapter() + + controller = DemoController( + agent=agent, + adapter=adapter, + demo_text=SAMPLE_DEMO, + max_retries=1, + correction_store=store, + ) + + # First call: fail verification (triggers correction lookup) + # Second call (after correction injected): pass verification + mock_verify.side_effect = [ + VerificationResult( + 
status="not_verified", + confidence=0.2, + explanation="Wrong pane clicked", + raw_response="", + ), + VerificationResult( + status="verified", + confidence=0.95, + explanation="Display settings visible", + raw_response="", + ), + # Step 2 passes immediately + VerificationResult( + status="verified", + confidence=0.9, + explanation="Resolution changed", + raw_response="", + ), + ] + + task = BenchmarkTask( + task_id="test-task", + instruction="Change the display resolution", + domain="desktop", + ) + + result = controller.execute(task, max_steps=10) + + # Verify the correction was used (step was replaced) + step1 = controller.plan_state.steps[0] + assert "third item" in step1.action.lower() or step1.status == "done" + + @patch("openadapt_evals.demo_controller.verify_step") + def test_no_correction_falls_through_to_replan(self, mock_verify, tmp_path): + """Without stored correction, falls through to normal replan.""" + store = CorrectionStore(str(tmp_path / "corrections")) + # Empty store - no corrections + + agent = _make_mock_agent() + adapter = _make_mock_adapter() + + controller = DemoController( + agent=agent, + adapter=adapter, + demo_text=SAMPLE_DEMO, + max_retries=1, + max_replans=1, + correction_store=store, + ) + + # All verifications fail + mock_verify.return_value = VerificationResult( + status="not_verified", + confidence=0.1, + explanation="Wrong state", + raw_response="", + ) + + task = BenchmarkTask( + task_id="other-task", + instruction="Something else", + domain="desktop", + ) + + with patch.object(controller, "_replan") as mock_replan: + # The controller will try correction store (empty), then replan + # We need to handle the replan to avoid infinite loop + def fake_replan(obs, step): + step.status = "failed" + controller._advance() + + mock_replan.side_effect = fake_replan + + result = controller.execute(task, max_steps=5) + + # Replan should have been called since no corrections existed + assert mock_replan.called + + def 
test_controller_accepts_correction_store_none(self): + """Controller works normally without correction store.""" + agent = _make_mock_agent() + adapter = _make_mock_adapter() + + controller = DemoController( + agent=agent, + adapter=adapter, + demo_text=SAMPLE_DEMO, + correction_store=None, + enable_correction_capture=False, + ) + assert controller.correction_store is None + assert controller.enable_correction_capture is False + + +# --------------------------------------------------------------------------- +# CorrectionCapture tests +# --------------------------------------------------------------------------- + + +class TestCorrectionCapture: + def test_capture_result_structure(self, tmp_path): + from openadapt_evals.correction_capture import CorrectionResult + + result = CorrectionResult( + screenshots=["/tmp/before.png", "/tmp/after.png"], + duration_seconds=5.0, + output_dir=str(tmp_path), + ) + assert len(result.screenshots) == 2 + assert result.duration_seconds == 5.0 + + @patch("openadapt_evals.correction_capture._has_recorder", return_value=False) + @patch("openadapt_evals.correction_capture._take_screenshot") + def test_capture_with_immediate_enter(self, mock_screenshot, mock_has_rec, tmp_path): + """Test capture completes when stdin signals immediately.""" + import io + + from openadapt_evals.correction_capture import CorrectionCapture + + mock_screenshot.return_value = str(tmp_path / "after.png") + + capture = CorrectionCapture(output_dir=str(tmp_path / "capture")) + + before_data = b"fake-png-data" + + # Mock select to signal stdin ready immediately, and provide a + # fake stdin that returns a line (avoids pytest capture conflict) + fake_stdin = io.StringIO("\n") + with patch("select.select", return_value=([fake_stdin], [], [])), patch( + "sys.stdin", fake_stdin + ): + result = capture.capture_correction( + failure_context={ + "screenshot_bytes": before_data, + "step_action": "Click Display", + "explanation": "Wrong button", + }, + timeout_seconds=1, 
+ ) + + assert result.output_dir == str(tmp_path / "capture") + assert result.duration_seconds > 0 + + +# --------------------------------------------------------------------------- +# End-to-end mock test +# --------------------------------------------------------------------------- + + +class TestCorrectionFlywheelE2E: + @patch("openadapt_evals.correction_parser.vlm_call") + @patch("openadapt_evals.demo_controller.verify_step") + def test_full_loop_mock(self, mock_verify, mock_vlm, tmp_path): + """Full loop: fail -> capture correction (mocked) -> store -> retrieve on next run.""" + store = CorrectionStore(str(tmp_path / "corrections")) + task = BenchmarkTask( + task_id="display-resolution", + instruction="Change the display resolution", + domain="desktop", + ) + + # --- Phase 1: Store a correction directly (simulating capture) --- + mock_vlm.return_value = json.dumps( + { + "think": "Display button is third item in sidebar", + "action": "Click the third item labeled Display in sidebar", + "expect": "Display settings pane opens", + } + ) + + from openadapt_evals.correction_parser import parse_correction + + correction_step = parse_correction( + step_action="Click the Display button in the left sidebar.", + failure_explanation="Agent clicked wrong pane", + before_screenshot=b"before", + after_screenshot=b"after", + ) + + store.save( + CorrectionEntry( + task_id="display-resolution", + step_description="Click the Display button in the left sidebar.", + failure_screenshot_path="/tmp/before.png", + failure_explanation="Agent clicked wrong pane", + correction_step=correction_step, + run_id="run-1", + ) + ) + + # --- Phase 2: Second run retrieves correction --- + agent = _make_mock_agent() + adapter = _make_mock_adapter() + + controller = DemoController( + agent=agent, + adapter=adapter, + demo_text=SAMPLE_DEMO, + max_retries=1, + correction_store=store, + ) + + # Step 1: fail first attempt, correction injected, second attempt succeeds + # Step 2: succeeds immediately + 
mock_verify.side_effect = [ + VerificationResult( + status="not_verified", + confidence=0.2, + explanation="Wrong pane", + raw_response="", + ), + VerificationResult( + status="verified", + confidence=0.95, + explanation="Display settings open", + raw_response="", + ), + VerificationResult( + status="verified", + confidence=0.9, + explanation="Resolution changed", + raw_response="", + ), + ] + + result = controller.execute(task, max_steps=10) + + # The correction should have been used + step1 = controller.plan_state.steps[0] + assert step1.status == "done" + assert "third item" in step1.action.lower() or "display" in step1.action.lower()