diff --git a/START_CLAUDE.bat b/START_CLAUDE.bat new file mode 100644 index 0000000..652333d --- /dev/null +++ b/START_CLAUDE.bat @@ -0,0 +1,89 @@ +@echo off +REM Agent Arena - Claude Agent (Anthropic API) +REM Starts the Python IPC server with Claude as the decision-making LLM +REM Then open Godot, load scenes/foraging.tscn, press F5, then SPACE + +echo ======================================== +echo Agent Arena - Claude Agent (Anthropic) +echo ======================================== +echo. + +REM Check for API key +if "%ANTHROPIC_API_KEY%"=="" ( + echo ERROR: ANTHROPIC_API_KEY is not set! + echo. + echo To set it permanently ^(recommended, only need to do this once^): + echo 1. Start menu ^> search "environment variables" + echo 2. Edit the system environment variables ^> Environment Variables + echo 3. Under User variables ^> New + echo 4. Name: ANTHROPIC_API_KEY Value: sk-ant-... + echo. + echo Or set it for this session only: + echo set ANTHROPIC_API_KEY=sk-ant-... + echo. + echo Get a key at: https://console.anthropic.com + pause + exit /b 1 +) + +cd /d "%~dp0" + +REM Check if .venv exists (project root venv) +if exist ".venv\" ( + echo Activating .venv... + call .venv\Scripts\activate.bat + goto :check_deps +) + +REM Check if python/venv exists (legacy venv) +if exist "python\venv\" ( + echo Activating python\venv... + call python\venv\Scripts\activate.bat + goto :check_deps +) + +echo ERROR: No Python virtual environment found! +echo Please run: python -m venv .venv +echo Then install dependencies: .venv\Scripts\pip install agent-arena-sdk anthropic +pause +exit /b 1 + +:check_deps +REM Check if required packages are installed +python -c "import agent_arena_sdk, anthropic" 2>nul +if errorlevel 1 ( + echo. + echo Installing required packages... + pip install agent-arena-sdk anthropic + if errorlevel 1 ( + echo. + echo Failed to install dependencies. + pause + exit /b 1 + ) +) + +echo. 
+echo Model : claude-sonnet-4-20250514 (change with --model flag) +echo Server : http://127.0.0.1:5000 +echo Debug : http://127.0.0.1:5000/debug +echo Cost : ~$0.10 per 100-tick run (Sonnet) +echo. +echo Next steps: +echo 1. Open Godot and load scenes/foraging.tscn +echo 2. Press F5 to run the scene +echo 3. Press SPACE to start the simulation +echo. +echo Press Ctrl+C to stop the server +echo ======================================== +echo. + +cd /d "%~dp0\starters\claude" +python run.py --debug %* + +REM If server exits, pause so user can see error +if errorlevel 1 ( + echo. + echo Server exited with error! + pause +) diff --git a/START_IPC_SERVER.bat b/START_IPC_SERVER.bat index 3408347..43b01d0 100644 --- a/START_IPC_SERVER.bat +++ b/START_IPC_SERVER.bat @@ -1,10 +1,11 @@ @echo off -REM Agent Arena - Foraging Demo Startup Script -REM Starts the Python IPC server with the beginner agent (new SDK pattern) -REM Then open Godot, load scenes/foraging.tscn, press F5, then SPACE +REM Agent Arena - IPC Server (Local LLM) +REM This is the original startup script. For clarity, you can also use: +REM START_LOCAL_LLM.bat - Same as this (local llama.cpp model) +REM START_CLAUDE.bat - Uses Anthropic Claude API instead echo ======================================== -echo Agent Arena - Foraging Demo (New SDK) +echo Agent Arena - Local LLM Agent echo ======================================== echo. diff --git a/START_LOCAL_LLM.bat b/START_LOCAL_LLM.bat new file mode 100644 index 0000000..47f9e91 --- /dev/null +++ b/START_LOCAL_LLM.bat @@ -0,0 +1,62 @@ +@echo off +REM Agent Arena - Local LLM Agent +REM Starts the Python IPC server with the local LLM agent (llama.cpp) +REM Then open Godot, load scenes/foraging.tscn, press F5, then SPACE + +echo ======================================== +echo Agent Arena - Local LLM Agent +echo ======================================== +echo. 
+ +cd /d "%~dp0\python" + +REM Check if venv exists +if not exist "venv\" ( + echo ERROR: Python virtual environment not found! + echo Please run: python -m venv venv + echo Then install dependencies: venv\Scripts\pip install -r requirements.txt + pause + exit /b 1 +) + +REM Activate venv +echo Activating Python virtual environment... +call venv\Scripts\activate.bat + +REM Check if required packages are installed +python -c "import fastapi, uvicorn" 2>nul +if errorlevel 1 ( + echo. + echo ERROR: Required packages not installed! + echo Installing dependencies... + pip install fastapi uvicorn + if errorlevel 1 ( + echo. + echo Failed to install dependencies. + pause + exit /b 1 + ) +) + +echo. +echo Starting Local LLM Agent (llama.cpp)... +echo Server will be available at: http://127.0.0.1:5000 +echo Debug inspector at: http://127.0.0.1:5000/debug +echo. +echo Next steps: +echo 1. Open Godot and load scenes/foraging.tscn +echo 2. Press F5 to run the scene +echo 3. Press SPACE to start the simulation +echo. +echo Press Ctrl+C to stop the server +echo ======================================== +echo. + +python run_foraging_demo.py + +REM If server exits, pause so user can see error +if errorlevel 1 ( + echo. + echo Server exited with error! + pause +) diff --git a/python/agent_runtime/__init__.py b/python/agent_runtime/__init__.py index 86f1c09..138d6f9 100644 --- a/python/agent_runtime/__init__.py +++ b/python/agent_runtime/__init__.py @@ -1,14 +1,14 @@ """ -Agent Arena - Agent Runtime Module +Agent Arena - Agent Runtime Module (DEPRECATED) This module provides the core agent runtime infrastructure for LLM-driven agents. -NOTE: After LDX refactor (Issue #60), behavior base classes and memory systems -have been moved to starter templates. Use agent-arena-sdk for new projects. +DEPRECATED: Use agent_arena_sdk for new projects. Shared types (Observation, +EntityInfo, etc.) are re-exported from the SDK. 
V1-only classes (AgentDecision, +WorldObject, SimpleContext) are still available here. """ from .agent import Agent -from .arena import AgentArena from .reasoning_trace import ( # Backwards compatibility DecisionCapture, @@ -27,17 +27,13 @@ from .runtime import AgentRuntime from .schemas import ( AgentDecision, - Constraint, EntityInfo, - Goal, HazardInfo, ItemInfo, - Metric, MetricDefinition, Objective, Observation, ResourceInfo, - ScenarioDefinition, SimpleContext, ToolSchema, ) @@ -46,7 +42,6 @@ __all__ = [ # Core "Agent", - "AgentArena", "AgentRuntime", "ToolDispatcher", # Tracing (new API) @@ -63,7 +58,7 @@ "InspectorStage", "get_global_inspector", "set_global_inspector", - # Observation/Decision schemas + # Observation/Decision schemas (re-exported from SDK) "Observation", "AgentDecision", "SimpleContext", @@ -72,11 +67,6 @@ "ResourceInfo", "HazardInfo", "ItemInfo", - # Scenario schemas - "ScenarioDefinition", - "Goal", - "Constraint", - "Metric", # Objective system (Issue #60) "Objective", "MetricDefinition", diff --git a/python/agent_runtime/arena.py b/python/agent_runtime/arena.py deleted file mode 100644 index 2aebdaf..0000000 --- a/python/agent_runtime/arena.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -AgentArena - Main orchestrator for running agents in simulation. -""" - -import logging -from typing import TYPE_CHECKING - -from .runtime import AgentRuntime - -if TYPE_CHECKING: - from ipc.server import IPCServer - - from .behavior import AgentBehavior - -logger = logging.getLogger(__name__) - - -class AgentArena: - """ - Main orchestrator connecting user agents to Godot simulation. 
- - AgentArena provides a simple API for users to: - - Register their agent behaviors - - Connect to the Godot simulation - - Run the agent decision loop - - Example: - from agent_runtime import AgentArena - from my_agents import ForagingAgent - - # Create arena and register agents - arena = AgentArena() - arena.register('agent_001', ForagingAgent()) - - # Connect to simulation and run - arena.connect(host='127.0.0.1', port=5000) - arena.run() # Blocks until simulation ends - - Example (class method): - # Alternatively, connect first then register - arena = AgentArena.connect(host='127.0.0.1', port=5000) - arena.register('agent_001', ForagingAgent()) - arena.run() - """ - - def __init__(self, max_workers: int = 4): - """ - Initialize AgentArena. - - Args: - max_workers: Maximum number of concurrent agent workers - """ - self.runtime = AgentRuntime(max_workers=max_workers) - self.behaviors: dict[str, AgentBehavior] = {} - self.ipc_server: IPCServer | None = None - self._running = False - - logger.info(f"Initialized AgentArena with {max_workers} workers") - - @classmethod - def connect( - cls, host: str = "127.0.0.1", port: int = 5000, max_workers: int = 4 - ) -> "AgentArena": - """ - Create arena and immediately connect to Godot simulation. - - Args: - host: IPC server host address - port: IPC server port - max_workers: Maximum number of concurrent agent workers - - Returns: - Connected AgentArena instance - - Example: - arena = AgentArena.connect() - arena.register('agent_001', MyAgent()) - arena.run() - """ - arena = cls(max_workers=max_workers) - arena._connect(host=host, port=port) - return arena - - def _connect(self, host: str = "127.0.0.1", port: int = 5000) -> None: - """ - Internal method to establish IPC connection. 
- - Args: - host: IPC server host address - port: IPC server port - """ - from ipc.server import create_server - - logger.info(f"Connecting to Godot simulation at {host}:{port}") - - # Create IPC server with this arena's runtime and behaviors - self.ipc_server = create_server( - runtime=self.runtime, behaviors=self.behaviors, host=host, port=port - ) - - # Create FastAPI app - self.ipc_server.create_app() - - logger.info("Connected to Godot simulation") - - def register(self, agent_id: str, behavior: "AgentBehavior") -> None: - """ - Register an agent behavior for the given ID. - - The behavior will be called each tick to make decisions for this agent. - - Args: - agent_id: Unique identifier for the agent (matches Godot agent ID) - behavior: AgentBehavior instance implementing decide() method - - Example: - arena.register('agent_001', MyForagingAgent()) - arena.register('agent_002', MyCombatAgent()) - """ - if agent_id in self.behaviors: - logger.warning(f"Replacing existing behavior for agent {agent_id}") - - self.behaviors[agent_id] = behavior - logger.info(f"Registered behavior for agent {agent_id}: {type(behavior).__name__}") - - def unregister(self, agent_id: str) -> None: - """ - Unregister an agent. - - Args: - agent_id: ID of agent to unregister - """ - if agent_id in self.behaviors: - del self.behaviors[agent_id] - logger.info(f"Unregistered agent {agent_id}") - else: - logger.warning(f"Cannot unregister agent {agent_id}: not found") - - def run(self) -> None: - """ - Run the main tick loop (blocking). - - This starts the IPC server and blocks until stopped. - The server will handle incoming tick requests from Godot and - call the registered agent behaviors to make decisions. - - Example: - arena = AgentArena.connect() - arena.register('agent_001', MyAgent()) - arena.run() # Blocks here - """ - if not self.ipc_server: - raise RuntimeError( - "Not connected to simulation. Call connect() or use AgentArena.connect() first." 
- ) - - if not self.behaviors: - logger.warning("No agent behaviors registered. Did you forget to call register()?") - - logger.info(f"Starting arena with {len(self.behaviors)} registered agents") - self._running = True - - try: - # Run the IPC server (blocking) - self.ipc_server.run() - except KeyboardInterrupt: - logger.info("Arena stopped by user") - finally: - self._running = False - self.runtime.stop() - - async def run_async(self) -> None: - """ - Run tick loop in background (async). - - This is an async version of run() that can be awaited. - - Example: - arena = AgentArena.connect() - arena.register('agent_001', MyAgent()) - await arena.run_async() - """ - if not self.ipc_server: - raise RuntimeError( - "Not connected to simulation. Call connect() or use AgentArena.connect() first." - ) - - if not self.behaviors: - logger.warning("No agent behaviors registered. Did you forget to call register()?") - - logger.info(f"Starting arena with {len(self.behaviors)} registered agents") - self._running = True - - try: - await self.ipc_server.run_async() - finally: - self._running = False - self.runtime.stop() - - def stop(self) -> None: - """ - Stop the tick loop. - - This can be called from another thread to stop a running arena. - """ - logger.info("Stopping arena...") - self._running = False - if self.runtime: - self.runtime.stop() - - def is_running(self) -> bool: - """ - Check if arena is currently running. - - Returns: - True if running, False otherwise - """ - return self._running - - def get_registered_agents(self) -> list[str]: - """ - Get list of registered agent IDs. - - Returns: - List of agent IDs - """ - return list(self.behaviors.keys()) - - def get_behavior(self, agent_id: str) -> "AgentBehavior | None": - """ - Get the behavior registered for an agent. 
- - Args: - agent_id: Agent ID - - Returns: - AgentBehavior instance or None if not found - """ - return self.behaviors.get(agent_id) diff --git a/python/agent_runtime/local_llm_behavior.py b/python/agent_runtime/local_llm_behavior.py deleted file mode 100644 index a0ee715..0000000 --- a/python/agent_runtime/local_llm_behavior.py +++ /dev/null @@ -1,607 +0,0 @@ -""" -LocalLLMBehavior - Bridges local LLM backends to AgentBehavior API. - -This module provides LocalLLMBehavior, a behavior class that wraps local LLM backends -(LlamaCppBackend, VLLMBackend) and implements the AgentBehavior interface, -allowing local GPU-accelerated LLMs to power agents via the IPC server. - -Unlike LLMAgentBehavior which uses external API services, LocalLLMBehavior -uses in-process GPU-accelerated inference via BaseBackend implementations. -""" - -import logging -from typing import TYPE_CHECKING - -from .behavior import AgentBehavior -from .memory import SlidingWindowMemory - -# Backwards compatibility -from .reasoning_trace import ( - PromptInspector, - TraceStepName, - get_global_inspector, -) -from .schemas import AgentDecision, Observation, ToolSchema - -if TYPE_CHECKING: - from backends.base import BaseBackend - -logger = logging.getLogger(__name__) - - -class LocalLLMBehavior(AgentBehavior): - """ - Agent behavior powered by local LLM backends. - - This class wraps local backends (LlamaCppBackend, VLLMBackend) and implements - the AgentBehavior interface. 
It handles: - - Prompt construction with system prompts and observations - - Memory management with sliding window - - Tool calling via the backend's generate_with_tools() method - - Parsing LLM responses into AgentDecision objects - - Graceful error handling (returns idle on failures) - - Example: - from backends.llama_cpp_backend import LlamaCppBackend - from backends.base import BackendConfig - - config = BackendConfig(model_path="path/to/model.gguf", n_gpu_layers=-1) - backend = LlamaCppBackend(config) - - behavior = LocalLLMBehavior( - backend=backend, - system_prompt="You are a foraging agent. Collect resources and avoid hazards.", - memory_capacity=10 - ) - - # Register with arena - arena.register("agent_001", behavior) - """ - - def __init__( - self, - backend: "BaseBackend", - system_prompt: str = "You are an autonomous agent in a simulation environment.", - memory_capacity: int = 10, - temperature: float = 0.7, - max_tokens: int = 256, - inspector: PromptInspector | None = None, - ): - """ - Initialize the local LLM behavior. 
- - Args: - backend: Local LLM backend (LlamaCppBackend or VLLMBackend) - system_prompt: System prompt describing the agent's role and task - memory_capacity: Number of recent observations to keep in memory - temperature: Temperature for generation (0-1) - max_tokens: Maximum tokens to generate per decision - inspector: Optional PromptInspector for debugging (uses global if None) - """ - self.backend = backend - self.system_prompt = system_prompt - self.memory = SlidingWindowMemory(capacity=memory_capacity) - self.temperature = temperature - self.max_tokens = max_tokens - self.inspector = inspector or get_global_inspector() - - # Validate backend is available - if not self.backend.is_available(): - raise RuntimeError("Backend is not available - model may not be loaded") - - logger.info( - f"Initialized LocalLLMBehavior with {type(backend).__name__} " - f"(memory_capacity={memory_capacity})" - ) - - def decide(self, observation: Observation, tools: list[ToolSchema]) -> AgentDecision: - """ - Decide what action to take given the current observation. - - Constructs a prompt from the system prompt, memory, and current observation, - then calls the backend's generate_with_tools() method to get a decision. 
- - If tracing is enabled (via enable_tracing()), each step is automatically logged: - - observation: The input observation - - prompt: The constructed prompt - - llm_response: The raw LLM response - - decision: The parsed decision - - Args: - observation: Current tick's observation from Godot - tools: List of available tools with their schemas - - Returns: - AgentDecision specifying which tool to call and with what parameters - """ - # Start trace capture - capture = self.inspector.start_capture(observation.agent_id, observation.tick) - self._current_trace = capture # Enable log_step() calls - - try: - # Store observation in memory - self.memory.store(observation) - - # Capture observation stage - if capture: - capture.add_step( - TraceStepName.OBSERVATION, - { - "agent_id": observation.agent_id, - "tick": observation.tick, - "position": observation.position, - "health": observation.health, - "energy": observation.energy, - "nearby_resources": [ - { - "name": r.name, - "type": r.type, - "distance": r.distance, - "position": r.position, - } - for r in observation.nearby_resources - ], - "nearby_hazards": [ - { - "name": h.name, - "type": h.type, - "distance": h.distance, - "damage": h.damage, - } - for h in observation.nearby_hazards - ], - "inventory": [ - {"name": item.name, "quantity": item.quantity} - for item in observation.inventory - ], - }, - ) - - # Build prompt from system prompt, memory, and observation - prompt = self._build_prompt(observation, tools) - logger.debug(f"Prompt length: {len(prompt)} chars (~{len(prompt) // 4} tokens)") - - # Log prompt (if tracing enabled) - self.log_step("prompt", {"text": prompt}) - - # Capture prompt building stage - if capture: - memory_items = self.memory.retrieve(limit=5) - capture.add_step( - TraceStepName.PROMPT_BUILDING, - { - "system_prompt": self.system_prompt, - "memory_context": { - "count": len(memory_items), - "items": [ - {"tick": obs.tick, "position": obs.position} for obs in memory_items - ], - }, - 
"final_prompt": prompt, - "prompt_length": len(prompt), - "estimated_tokens": len(prompt) // 4, - }, - ) - - # Convert tools to dict format for backend - tool_dicts = [ - { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - } - for tool in tools - ] - - # Capture LLM request stage - if capture: - capture.add_step( - TraceStepName.LLM_REQUEST, - { - "model": getattr(self.backend, "model_name", "unknown"), - "prompt": prompt, - "tools": tool_dicts, - "temperature": self.temperature, - "max_tokens": self.max_tokens, - }, - ) - - # Generate response using backend - import time - - start_time = time.time() - logger.debug(f"Generating decision for agent {observation.agent_id}") - result = self.backend.generate_with_tools( - prompt=prompt, tools=tool_dicts, temperature=self.temperature - ) - elapsed_ms = (time.time() - start_time) * 1000 - - # Capture LLM response stage - if capture: - capture.add_step( - TraceStepName.LLM_RESPONSE, - { - "raw_text": result.text, - "tokens_used": result.tokens_used, - "finish_reason": result.finish_reason, - "metadata": result.metadata, - "latency_ms": elapsed_ms, - }, - ) - - # Check for generation errors - if result.finish_reason == "error": - error_msg = result.metadata.get("error", "Unknown error") - logger.error(f"LLM generation error: {error_msg}") - decision = AgentDecision.idle(reasoning=f"LLM error: {error_msg}") - - # Capture error decision - if capture: - capture.add_step( - TraceStepName.DECISION, - { - "tool": decision.tool, - "params": decision.params, - "reasoning": decision.reasoning, - "total_latency_ms": elapsed_ms, - "error": error_msg, - }, - ) - - self.inspector.finish_capture(observation.agent_id, observation.tick) - self._current_trace = None # Clear for next decision - return decision - - # Debug: log token usage and raw response - logger.debug(f"Tokens used: {result.tokens_used}") - logger.debug(f"Raw LLM response: {result.text[:500] if result.text else '(empty)'}") - - # 
Log LLM response (if tracing enabled) - self.log_step( - "llm_response", - { - "text": result.text, - "tokens_used": result.tokens_used, - "finish_reason": result.finish_reason, - "elapsed_ms": elapsed_ms, - }, - ) - - # Try to parse from metadata first (pre-parsed by backend) - if "parsed_tool_call" in result.metadata: - parsed = result.metadata["parsed_tool_call"] - decision = AgentDecision( - tool=parsed.get("tool", "idle"), - params=parsed.get("params", {}), - reasoning=parsed.get("reasoning", ""), - ) - elif "tool_call" in result.metadata: - # Native tool call from vLLM - tool_call = result.metadata["tool_call"] - decision = AgentDecision( - tool=tool_call["name"], - params=tool_call["arguments"], - reasoning=result.text or "LLM tool call", - ) - else: - # Parse the raw text response using robust JSON extraction - try: - decision = AgentDecision.from_llm_response(result.text) - except ValueError as e: - logger.warning(f"Failed to parse LLM response: {e}") - logger.debug(f"Raw response: {result.text}") - decision = AgentDecision.idle(reasoning=f"Parse error: {e}") - - # Capture final decision stage - if capture: - capture.add_step( - TraceStepName.DECISION, - { - "tool": decision.tool, - "params": decision.params, - "reasoning": decision.reasoning, - "total_latency_ms": elapsed_ms, - }, - ) - - logger.info( - f"Agent {observation.agent_id} decided: {decision.tool} - {decision.reasoning} " - f"(LLM took {elapsed_ms:.0f}ms, {result.tokens_used} tokens)" - ) - - # Finish capture and optionally write to file - self.inspector.finish_capture(observation.agent_id, observation.tick) - self._current_trace = None # Clear for next decision - - return decision - - except Exception as e: - logger.error(f"Error in LocalLLMBehavior.decide(): {e}", exc_info=True) - - # Capture error in inspector - if capture: - capture.add_step( - TraceStepName.DECISION, - {"tool": "idle", "params": {}, "reasoning": f"Error: {e}", "error": str(e)}, - ) - 
self.inspector.finish_capture(observation.agent_id, observation.tick) - self._current_trace = None # Clear for next decision - - return AgentDecision.idle(reasoning=f"Error: {e}") - - def _build_prompt(self, observation: Observation, tools: list[ToolSchema]) -> str: - """ - Build the prompt for LLM generation. - - Includes system prompt, memory context, current observation, and available tools. - - Args: - observation: Current observation - tools: Available tools - - Returns: - Formatted prompt string - """ - sections = [] - - # Add system prompt - if self.system_prompt: - sections.append(self.system_prompt) - sections.append("") - - # Add memory context if available - memory_items = self.memory.retrieve(limit=5) - if memory_items and len(memory_items) > 1: # Don't include if only current observation - sections.append("## Recent History") - # Memory items are returned most recent first, so reverse and skip last (current) - for i, obs in enumerate(reversed(memory_items[1:]), 1): - sections.append(f" {i}. 
Tick {obs.tick}: Position {obs.position}") - if obs.nearby_resources: - sections.append(f" Resources nearby: {len(obs.nearby_resources)}") - if obs.nearby_hazards: - sections.append(f" Hazards nearby: {len(obs.nearby_hazards)}") - sections.append("") - - # Current state - sections.append("## Current Situation") - sections.append(f"Position: {observation.position}") - sections.append(f"Health: {observation.health}") - sections.append(f"Energy: {observation.energy}") - sections.append(f"Tick: {observation.tick}") - - # Add analysis hints (but let LLM reason) - danger_hazards = [h for h in observation.nearby_hazards if h.distance < 2.0] - collect_resources = [r for r in observation.nearby_resources if r.distance < 1.0] - - sections.append("\n## Quick Analysis") - sections.append(f"Hazards in danger zone (< 2.0): {len(danger_hazards)}") - sections.append(f"Resources in collection range (< 1.0): {len(collect_resources)}") - - # Nearby resources - if observation.nearby_resources: - sections.append("\n## Nearby Resources") - for resource in observation.nearby_resources[:5]: # Limit to 5 for brevity - sections.append( - f"- {resource.name} ({resource.type}) at distance {resource.distance:.1f}, " - f"position {resource.position}" - ) - else: - sections.append("\n## Nearby Resources\nNone visible") - - # Nearby hazards - if observation.nearby_hazards: - sections.append("\n## Nearby Hazards") - for hazard in observation.nearby_hazards[:5]: # Limit to 5 for brevity - damage_str = f", damage: {hazard.damage}" if hazard.damage > 0 else "" - sections.append( - f"- {hazard.name} ({hazard.type}) at distance {hazard.distance:.1f}{damage_str}, " - f"position {hazard.position}" - ) - else: - sections.append("\n## Nearby Hazards\nNone visible") - - # Remembered objects from world_map (out of sight but known) - visible_names = {r.name for r in observation.nearby_resources} - visible_names.update(h.name for h in observation.nearby_hazards) - - remembered_resources = [ - obj for obj in 
self.world_map.get_resources() if obj.name not in visible_names - ] - remembered_hazards = [ - obj for obj in self.world_map.get_hazards() if obj.name not in visible_names - ] - - if remembered_resources or remembered_hazards: - sections.append("\n## Remembered Objects (out of sight)") - if remembered_resources: - # Sort by distance from current position - remembered_resources.sort(key=lambda obj: obj.distance_to(observation.position)) - sections.append("Resources:") - for obj in remembered_resources[:5]: - dist = obj.distance_to(observation.position) - stale = observation.tick - obj.last_seen_tick - sections.append( - f"- {obj.name} ({obj.subtype}) at position {obj.position}, " - f"~{dist:.1f} units away (last seen {stale} ticks ago)" - ) - if remembered_hazards: - remembered_hazards.sort(key=lambda obj: obj.distance_to(observation.position)) - sections.append("Hazards:") - for obj in remembered_hazards[:3]: - dist = obj.distance_to(observation.position) - sections.append( - f"- {obj.name} ({obj.subtype}) at position {obj.position}, " - f"~{dist:.1f} units away, damage: {obj.damage}" - ) - - # Recent Experiences (collisions, damage, traps) - # Note: Always access world_map - it's lazy-loaded and always returns a valid instance. - # Don't use truthiness check as SpatialMemory.__len__ returns object count which may be 0. - experiences = self.world_map.get_recent_experiences(limit=5) - if experiences: - sections.append("\n## Recent Experiences") - for exp in experiences: - if exp.event_type == "collision": - sections.append( - f"- Tick {exp.tick}: Movement BLOCKED by {exp.object_name} " - f"at position {exp.position}" - ) - elif exp.event_type == "damage": - sections.append( - f"- Tick {exp.tick}: Took {exp.damage_taken:.1f} damage from " - f"{exp.object_name} at {exp.position}" - ) - elif exp.event_type == "trapped": - ticks_trapped = exp.metadata.get("trap_duration", "?") - sections.append( - f"- Tick {exp.tick}: TRAPPED by {exp.object_name}! 
" - f"Lost {ticks_trapped} ticks" - ) - - # Known Obstacles from collisions - obstacles = self.world_map.query_by_type("obstacle") - if obstacles: - sections.append("\n## Known Obstacles (from collisions)") - for obstacle in obstacles[:5]: - sections.append(f"- {obstacle.name} at {obstacle.position}") - - # Exploration Status - if observation.exploration: - exploration = observation.exploration - sections.append("\n## Exploration Status") - sections.append(f"Area explored: {exploration.exploration_percentage:.1f}%") - - if exploration.frontiers_by_direction: - sections.append("Unexplored frontiers:") - # Sort by distance (nearest first) - sorted_frontiers = sorted( - exploration.frontiers_by_direction.items(), key=lambda x: x[1] - ) - for direction, distance in sorted_frontiers[:5]: - sections.append(f"- {direction.upper()}: {distance:.1f} units away") - - if exploration.explore_targets: - sections.append("\nSuggested exploration targets:") - for target in exploration.explore_targets[:3]: - sections.append( - f"- {target.direction.upper()}: position {target.position}, " - f"{target.distance:.1f} units away" - ) - - # Inventory - if observation.inventory: - sections.append("\n## Inventory") - for item in observation.inventory: - sections.append(f"- {item.name} x{item.quantity}") - else: - sections.append("\n## Inventory\nEmpty") - - # Add instruction for response format (CoT) - sections.append("\n## Your Task") - sections.append( - "Follow the RESPONSE FORMAT from your instructions. " - "Show your THINKING step by step, then output your ACTION as JSON." - ) - - return "\n".join(sections) - - def on_tool_result(self, tool: str, result: dict) -> None: - """ - Called after a tool execution completes. - - Can be overridden to update memory or adjust strategy based on tool results. 
- - Args: - tool: Name of the tool that was executed - result: Result dictionary from the tool - """ - logger.debug(f"Tool '{tool}' executed with result: {result}") - - def on_episode_start(self) -> None: - """ - Called when a new episode begins. - - Clears memory to start fresh. If tracing is enabled, starts a new trace episode. - """ - super().on_episode_start() # Handle trace episode start and world_map clearing - logger.info("Episode started, clearing memory") - self.memory.clear() - # Note: world_map clearing is handled by super().on_episode_start() - # which checks _world_map is not None (avoiding SpatialMemory truthiness issue) - - def on_episode_end(self, success: bool, metrics: dict | None = None) -> None: - """ - Called when an episode ends. - - Args: - success: Whether the episode goal was achieved - metrics: Optional metrics from the scenario - """ - logger.info(f"Episode ended: success={success}, observations_stored={len(self.memory)}") - if metrics: - logger.info(f"Episode metrics: {metrics}") - super().on_episode_end(success, metrics) # Handle trace episode end - - -def create_local_llm_behavior( - model_path: str, - system_prompt: str = "", - n_gpu_layers: int = -1, - temperature: float = 0.7, - max_tokens: int = 256, - memory_capacity: int = 10, -) -> LocalLLMBehavior: - """ - Factory function to create a LocalLLMBehavior with LlamaCppBackend. - - This is a convenience function that handles backend creation. - For more control, create the backend manually and pass it to LocalLLMBehavior. 
- - Args: - model_path: Path to the GGUF model file - system_prompt: System prompt for the agent - n_gpu_layers: GPU layers to offload (-1 = all, 0 = CPU only) - temperature: LLM temperature (0-1) - max_tokens: Maximum tokens per response - memory_capacity: Number of recent observations to keep in memory - - Returns: - Configured LocalLLMBehavior instance - - Example: - behavior = create_local_llm_behavior( - model_path="models/mistral-7b.gguf", - system_prompt="You are a foraging agent.", - n_gpu_layers=-1 - ) - """ - from backends import BackendConfig, LlamaCppBackend - - # Use default foraging prompt if none provided - if not system_prompt: - system_prompt = """You are an autonomous foraging agent in a simulation environment. - -Your goal is to: -1. Collect resources (like apples) to increase your score -2. Avoid hazards (like fire) that can damage you -3. Manage your health and energy efficiently - -When you receive an observation, analyze the situation and choose the best action. -Prioritize safety (avoid hazards) over collecting resources.""" - - config = BackendConfig( - model_path=model_path, - temperature=temperature, - max_tokens=max_tokens, - n_gpu_layers=n_gpu_layers, - ) - - backend = LlamaCppBackend(config) - - return LocalLLMBehavior( - backend=backend, - system_prompt=system_prompt, - memory_capacity=memory_capacity, - temperature=temperature, - max_tokens=max_tokens, - ) diff --git a/python/agent_runtime/schemas.py b/python/agent_runtime/schemas.py index 4efc427..0cbde87 100644 --- a/python/agent_runtime/schemas.py +++ b/python/agent_runtime/schemas.py @@ -1,16 +1,41 @@ """ -Core data schemas for Agent Arena. +Agent Runtime schemas (DEPRECATED — use agent_arena_sdk for new projects). -Defines the contracts between framework components, user code, and Godot. +Shared types (Observation, EntityInfo, etc.) are re-exported from the SDK, +which is the single source of truth. V1-only classes that still have +internal consumers are kept here. 
""" import json import logging from dataclasses import dataclass, field +# --------------------------------------------------------------------------- +# Re-exports from SDK (single source of truth) +# --------------------------------------------------------------------------- +from agent_arena_sdk.schemas import ( # noqa: F401 + EntityInfo, + ExplorationInfo, + ExploreTarget, + HazardInfo, + ItemInfo, + MetricDefinition, + Objective, + Observation, + ResourceInfo, + StationInfo, + ToolResult, + ToolSchema, +) + logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# V1-only classes (still used by agent_runtime internals) +# --------------------------------------------------------------------------- + + @dataclass class WorldObject: """A remembered object in the world map. @@ -143,569 +168,6 @@ def from_dict(cls, data: dict) -> "ExperienceEvent": ) -@dataclass -class EntityInfo: - """Information about a visible entity.""" - - id: str - type: str - position: tuple[float, float, float] - distance: float - metadata: dict = field(default_factory=dict) - - -@dataclass -class ResourceInfo: - """Information about a nearby resource.""" - - name: str - type: str - position: tuple[float, float, float] - distance: float - - -@dataclass -class HazardInfo: - """Information about a nearby hazard.""" - - name: str - type: str - position: tuple[float, float, float] - distance: float - damage: float = 0.0 - - -@dataclass -class ExploreTarget: - """A potential exploration target.""" - - direction: str # "north", "south", "east", "west", etc. - distance: float - position: tuple[float, float, float] - - -@dataclass -class ExplorationInfo: - """Information about world exploration status. - - Tracks what percentage of the world the agent has seen and - provides information about unexplored frontiers. 
- """ - - exploration_percentage: float # 0-100 - total_cells: int - seen_cells: int - frontiers_by_direction: dict[str, float] # direction -> distance to nearest frontier - explore_targets: list[ExploreTarget] = field(default_factory=list) - - @classmethod - def from_dict(cls, data: dict) -> "ExplorationInfo": - """Create from dictionary.""" - targets = [] - for t in data.get("explore_targets", []): - targets.append( - ExploreTarget( - direction=t["direction"], - distance=t["distance"], - position=tuple(t["position"]), - ) - ) - return cls( - exploration_percentage=data.get("exploration_percentage", 0.0), - total_cells=data.get("total_cells", 0), - seen_cells=data.get("seen_cells", 0), - frontiers_by_direction=data.get("frontiers_by_direction", {}), - explore_targets=targets, - ) - - def to_dict(self) -> dict: - """Convert to dictionary.""" - return { - "exploration_percentage": self.exploration_percentage, - "total_cells": self.total_cells, - "seen_cells": self.seen_cells, - "frontiers_by_direction": self.frontiers_by_direction, - "explore_targets": [ - { - "direction": t.direction, - "distance": t.distance, - "position": list(t.position), - } - for t in self.explore_targets - ], - } - - -@dataclass -class StationInfo: - """Information about a nearby crafting station.""" - - name: str - type: str # "workbench", "anvil", "furnace" - position: tuple[float, float, float] - distance: float - - -@dataclass -class ToolResult: - """Result of a tool execution from the previous tick. - - Sent from Godot in the observation payload so the agent knows - whether its last action succeeded or failed. 
- - Attributes: - tool: Name of the tool that was executed (e.g., "move_to", "collect") - success: Whether the tool execution succeeded - result: Full result dictionary from the tool - error: Error message if the tool failed (empty string if success) - duration_ticks: How many simulation ticks the tool took to complete - """ - - tool: str - success: bool - result: dict = field(default_factory=dict) - error: str = "" - duration_ticks: int = 0 - - @classmethod - def from_dict(cls, data: dict) -> "ToolResult": - """Create from dictionary.""" - return cls( - tool=data["tool"], - success=data.get("success", False), - result=data.get("result", {}), - error=data.get("error", ""), - duration_ticks=data.get("duration_ticks", 0), - ) - - def to_dict(self) -> dict: - """Convert to dictionary.""" - return { - "tool": self.tool, - "success": self.success, - "result": self.result, - "error": self.error, - "duration_ticks": self.duration_ticks, - } - - -@dataclass -class ItemInfo: - """Information about an inventory item.""" - - id: str - name: str - quantity: int = 1 - - -@dataclass -class ToolSchema: - """Schema for an available tool.""" - - name: str - description: str - parameters: dict # JSON Schema format - - def to_openai_format(self) -> dict: - """ - Convert to OpenAI function calling format. - - Returns: - Dictionary in OpenAI function calling format - """ - return { - "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, - } - - def to_anthropic_format(self) -> dict: - """ - Convert to Anthropic tool calling format. - - Returns: - Dictionary in Anthropic tool calling format - """ - return { - "name": self.name, - "description": self.description, - "input_schema": self.parameters, - } - - @classmethod - def from_dict(cls, data: dict) -> "ToolSchema": - """ - Create ToolSchema from dictionary. 
- - Args: - data: Dictionary with 'name', 'description', and 'parameters' - - Returns: - ToolSchema instance - """ - return cls( - name=data["name"], - description=data["description"], - parameters=data["parameters"], - ) - - -@dataclass -class MetricDefinition: - """ - Definition of a success metric for an objective. - - Part of the objective system (Issue #60 LDX refactor). - - Attributes: - target: The target value to achieve - weight: How important this metric is (default 1.0) - lower_is_better: Whether lower values are better (e.g., time_taken) - required: Whether this metric must be met to succeed - """ - - target: float - weight: float = 1.0 - lower_is_better: bool = False - required: bool = False - - @classmethod - def from_dict(cls, data: dict) -> "MetricDefinition": - """Create MetricDefinition from dictionary.""" - return cls( - target=data["target"], - weight=data.get("weight", 1.0), - lower_is_better=data.get("lower_is_better", False), - required=data.get("required", False), - ) - - def to_dict(self) -> dict: - """Convert to dictionary for serialization.""" - return { - "target": self.target, - "weight": self.weight, - "lower_is_better": self.lower_is_better, - "required": self.required, - } - - -@dataclass -class Objective: - """ - Scenario-defined goals for the agent. - - Part of the objective system (Issue #60 LDX refactor). - Objectives are passed from the game scenario to the agent via observations. - This enables general-purpose agents that adapt to different goals. 
- - Attributes: - description: Human-readable description of the objective - success_metrics: Dictionary of metric names to their definitions - time_limit: Time limit in ticks (0 = unlimited) - - Example: - objective = Objective( - description="Collect resources while avoiding hazards", - success_metrics={ - "resources_collected": MetricDefinition(target=10, weight=1.0), - "health_remaining": MetricDefinition(target=50, weight=0.5) - }, - time_limit=600 - ) - """ - - description: str - success_metrics: dict[str, MetricDefinition] = field(default_factory=dict) - time_limit: int = 0 # 0 = unlimited - - @classmethod - def from_dict(cls, data: dict) -> "Objective": - """ - Create Objective from dictionary. - - Args: - data: Dictionary from IPC message - - Returns: - Objective instance - """ - success_metrics = {} - for name, metric_data in data.get("success_metrics", {}).items(): - success_metrics[name] = MetricDefinition.from_dict(metric_data) - - return cls( - description=data["description"], - success_metrics=success_metrics, - time_limit=data.get("time_limit", 0), - ) - - def to_dict(self) -> dict: - """ - Convert to dictionary for serialization. 
- - Returns: - Dictionary representation - """ - return { - "description": self.description, - "success_metrics": { - name: metric.to_dict() for name, metric in self.success_metrics.items() - }, - "time_limit": self.time_limit, - } - - -@dataclass -class Observation: - """What the agent receives from Godot each tick.""" - - agent_id: str - tick: int - position: tuple[float, float, float] - rotation: tuple[float, float, float] | None = None - velocity: tuple[float, float, float] | None = None - visible_entities: list[EntityInfo] = field(default_factory=list) - nearby_resources: list[ResourceInfo] = field(default_factory=list) - nearby_hazards: list[HazardInfo] = field(default_factory=list) - nearby_stations: list[StationInfo] = field(default_factory=list) - inventory: list[ItemInfo] = field(default_factory=list) - health: float = 100.0 - energy: float = 100.0 - perception_radius: float = 50.0 # Max perception distance (default for backward compat) - exploration: ExplorationInfo | None = None # World exploration status - # Objective system fields (NEW - Issue #60 LDX refactor) - scenario_name: str = "" - objective: Objective | None = None - current_progress: dict[str, float] = field(default_factory=dict) - custom: dict = field(default_factory=dict) - # Tool result feedback (Issue #71) - last_tool_result: ToolResult | None = None - - @classmethod - def from_dict(cls, data: dict) -> "Observation": - """ - Create Observation from IPC dictionary. 
- - Args: - data: Dictionary from Godot IPC - - Returns: - Observation instance - """ - # Parse position (required) - position = ( - tuple(data["position"]) if isinstance(data["position"], list) else data["position"] - ) - - # Parse optional rotation - rotation = None - if "rotation" in data and data["rotation"] is not None: - rotation = ( - tuple(data["rotation"]) if isinstance(data["rotation"], list) else data["rotation"] - ) - - # Parse optional velocity - velocity = None - if "velocity" in data and data["velocity"] is not None: - velocity = ( - tuple(data["velocity"]) if isinstance(data["velocity"], list) else data["velocity"] - ) - - # Parse visible entities - visible_entities = [] - for entity_data in data.get("visible_entities", []): - entity_pos = ( - tuple(entity_data["position"]) - if isinstance(entity_data["position"], list) - else entity_data["position"] - ) - visible_entities.append( - EntityInfo( - id=entity_data["id"], - type=entity_data["type"], - position=entity_pos, - distance=entity_data["distance"], - metadata=entity_data.get("metadata", {}), - ) - ) - - # Parse nearby resources - nearby_resources = [] - for resource_data in data.get("nearby_resources", []): - resource_pos = ( - tuple(resource_data["position"]) - if isinstance(resource_data["position"], list) - else resource_data["position"] - ) - nearby_resources.append( - ResourceInfo( - name=resource_data["name"], - type=resource_data["type"], - position=resource_pos, - distance=resource_data["distance"], - ) - ) - - # Parse nearby hazards - nearby_hazards = [] - for hazard_data in data.get("nearby_hazards", []): - hazard_pos = ( - tuple(hazard_data["position"]) - if isinstance(hazard_data["position"], list) - else hazard_data["position"] - ) - nearby_hazards.append( - HazardInfo( - name=hazard_data["name"], - type=hazard_data["type"], - position=hazard_pos, - distance=hazard_data["distance"], - damage=hazard_data.get("damage", 0.0), - ) - ) - - # Parse nearby stations - nearby_stations = [] 
- for station_data in data.get("nearby_stations", []): - station_pos = ( - tuple(station_data["position"]) - if isinstance(station_data["position"], list) - else station_data["position"] - ) - nearby_stations.append( - StationInfo( - name=station_data["name"], - type=station_data["type"], - position=station_pos, - distance=station_data["distance"], - ) - ) - - # Parse inventory (list[ItemInfo] format; dict format goes to custom) - inventory = [] - raw_inventory = data.get("inventory", []) - if isinstance(raw_inventory, list): - for item_data in raw_inventory: - if isinstance(item_data, dict) and "id" in item_data: - inventory.append( - ItemInfo( - id=item_data["id"], - name=item_data["name"], - quantity=item_data.get("quantity", 1), - ) - ) - - # Parse exploration data - exploration = None - if "exploration" in data and data["exploration"]: - exploration = ExplorationInfo.from_dict(data["exploration"]) - - # Parse objective (NEW - Issue #60) - objective = None - if "objective" in data and data["objective"]: - objective = Objective.from_dict(data["objective"]) - - # Parse tool result from last action (Issue #71) - last_tool_result = None - if "tool_result" in data and data["tool_result"]: - last_tool_result = ToolResult.from_dict(data["tool_result"]) - - return cls( - agent_id=data["agent_id"], - tick=data["tick"], - position=position, - rotation=rotation, - velocity=velocity, - visible_entities=visible_entities, - nearby_resources=nearby_resources, - nearby_hazards=nearby_hazards, - nearby_stations=nearby_stations, - inventory=inventory, - health=data.get("health", 100.0), - energy=data.get("energy", 100.0), - perception_radius=data.get("perception_radius", 50.0), - exploration=exploration, - scenario_name=data.get("scenario_name", ""), - objective=objective, - current_progress=data.get("current_progress", {}), - custom=data.get("custom", {}), - last_tool_result=last_tool_result, - ) - - def to_dict(self) -> dict: - """ - Convert to dictionary for serialization. 
- - Returns: - Dictionary representation - """ - return { - "agent_id": self.agent_id, - "tick": self.tick, - "position": list(self.position), - "rotation": list(self.rotation) if self.rotation else None, - "velocity": list(self.velocity) if self.velocity else None, - "visible_entities": [ - { - "id": e.id, - "type": e.type, - "position": list(e.position), - "distance": e.distance, - "metadata": e.metadata, - } - for e in self.visible_entities - ], - "nearby_resources": [ - { - "name": r.name, - "type": r.type, - "position": list(r.position), - "distance": r.distance, - } - for r in self.nearby_resources - ], - "nearby_hazards": [ - { - "name": h.name, - "type": h.type, - "position": list(h.position), - "distance": h.distance, - "damage": h.damage, - } - for h in self.nearby_hazards - ], - "nearby_stations": [ - { - "name": s.name, - "type": s.type, - "position": list(s.position), - "distance": s.distance, - } - for s in self.nearby_stations - ], - "inventory": [ - { - "id": i.id, - "name": i.name, - "quantity": i.quantity, - } - for i in self.inventory - ], - "health": self.health, - "energy": self.energy, - "perception_radius": self.perception_radius, - "exploration": self.exploration.to_dict() if self.exploration else None, - "scenario_name": self.scenario_name, - "objective": self.objective.to_dict() if self.objective else None, - "current_progress": self.current_progress, - "custom": self.custom, - "tool_result": self.last_tool_result.to_dict() if self.last_tool_result else None, - } - - @dataclass class AgentDecision: """What the agent returns to the framework.""" @@ -1000,370 +462,3 @@ def from_observation(cls, obs: Observation, goal: str | None = None) -> "SimpleC goal=goal, tick=obs.tick, ) - - -# ============================================================================= -# Scenario Definition Schemas -# ============================================================================= - - -@dataclass -class Goal: - """A scenario goal that the agent should 
achieve.""" - - name: str - description: str - success_condition: str # Human-readable description of success - priority: int = 1 # Lower = higher priority - optional: bool = False - - -@dataclass -class Constraint: - """A constraint or rule the agent must follow.""" - - name: str - description: str - penalty: str | None = None # What happens if violated - - -@dataclass -class Metric: - """A metric used to evaluate agent performance.""" - - name: str - description: str - unit: str | None = None - optimize: str = "maximize" # "maximize", "minimize", or "target" - target_value: float | None = None # For "target" optimization - - -@dataclass -class ScenarioDefinition: - """ - Complete definition of a scenario for both LLM agents and documentation. - - This is the single source of truth for scenario information: - - LLM agents use it to understand their task - - Documentation is auto-generated from it - - Framework validates against it - """ - - # Identity (required fields first - no defaults) - name: str - id: str # Machine-readable identifier (e.g., "foraging") - tier: int # Learning tier: 1=beginner, 2=intermediate, 3=advanced - description: str - - # Optional fields with defaults - version: str = "1.0.0" - backstory: str | None = None # Optional narrative context - - # Goals and constraints - goals: list[Goal] = field(default_factory=list) - constraints: list[Constraint] = field(default_factory=list) - - # Available tools (references to tool names, full schemas loaded separately) - available_tools: list[str] = field(default_factory=list) - - # Success metrics - metrics: list[Metric] = field(default_factory=list) - success_threshold: dict = field(default_factory=dict) # Metric name -> value - - # Perception info - perception_info: dict = field(default_factory=dict) # What agent can observe - - # Hints for learners (not sent to LLM by default) - hints: list[str] = field(default_factory=list) - learning_objectives: list[str] = field(default_factory=list) - - # 
Resource types in this scenario - resource_types: list[dict] = field(default_factory=list) - hazard_types: list[dict] = field(default_factory=list) - - def to_system_prompt(self, include_hints: bool = False) -> str: - """ - Generate a system prompt section for LLM agents. - - Args: - include_hints: Whether to include learner hints - - Returns: - Formatted string for system prompt - """ - sections = [] - - # Scenario overview - sections.append(f"# Scenario: {self.name}\n") - sections.append(self.description) - if self.backstory: - sections.append(f"\n{self.backstory}") - - # Goals - sections.append("\n## Goals") - for goal in sorted(self.goals, key=lambda g: g.priority): - optional_tag = " (optional)" if goal.optional else "" - sections.append(f"- **{goal.name}**{optional_tag}: {goal.description}") - sections.append(f" - Success: {goal.success_condition}") - - # Constraints - if self.constraints: - sections.append("\n## Constraints") - for constraint in self.constraints: - sections.append(f"- **{constraint.name}**: {constraint.description}") - if constraint.penalty: - sections.append(f" - Penalty: {constraint.penalty}") - - # Available tools - if self.available_tools: - sections.append("\n## Available Tools") - for tool_name in self.available_tools: - sections.append(f"- `{tool_name}`") - - # Perception info - if self.perception_info: - sections.append("\n## Perception") - for key, value in self.perception_info.items(): - sections.append(f"- **{key}**: {value}") - - # Resource types - if self.resource_types: - sections.append("\n## Resource Types") - for rt in self.resource_types: - sections.append(f"- **{rt['name']}** ({rt['type']}): {rt.get('description', '')}") - - # Hazard types - if self.hazard_types: - sections.append("\n## Hazard Types") - for ht in self.hazard_types: - sections.append( - f"- **{ht['name']}** ({ht['type']}): {ht.get('description', '')} " - f"[Damage: {ht.get('damage', 'unknown')}]" - ) - - # Success metrics - if self.metrics: - 
sections.append("\n## Success Metrics") - for metric in self.metrics: - unit_str = f" ({metric.unit})" if metric.unit else "" - sections.append(f"- **{metric.name}**{unit_str}: {metric.description}") - if metric.optimize == "target" and metric.target_value is not None: - sections.append(f" - Target: {metric.target_value}") - else: - sections.append(f" - Goal: {metric.optimize}") - - # Optional hints - if include_hints and self.hints: - sections.append("\n## Hints") - for hint in self.hints: - sections.append(f"- {hint}") - - return "\n".join(sections) - - def to_markdown(self) -> str: - """ - Generate full markdown documentation for learners. - - Returns: - Complete markdown document - """ - sections = [] - - # Header - sections.append(f"# {self.name}") - sections.append( - f"\n**Tier:** {self.tier} | **ID:** `{self.id}` | **Version:** {self.version}\n" - ) - - # Description - sections.append("## Overview\n") - sections.append(self.description) - if self.backstory: - sections.append(f"\n> {self.backstory}") - - # Learning objectives - if self.learning_objectives: - sections.append("\n## Learning Objectives\n") - sections.append("After completing this scenario, you will understand:\n") - for obj in self.learning_objectives: - sections.append(f"- {obj}") - - # Goals - sections.append("\n## Goals\n") - for goal in sorted(self.goals, key=lambda g: g.priority): - priority_badge = f"[Priority {goal.priority}]" if goal.priority > 1 else "[Primary]" - optional_badge = " *(Optional)*" if goal.optional else "" - sections.append(f"### {goal.name} {priority_badge}{optional_badge}\n") - sections.append(goal.description) - sections.append(f"\n**Success Condition:** {goal.success_condition}\n") - - # Constraints - if self.constraints: - sections.append("## Constraints\n") - sections.append("Your agent must operate within these rules:\n") - sections.append("| Constraint | Description | Penalty |") - sections.append("|------------|-------------|---------|") - for constraint in 
self.constraints: - penalty = constraint.penalty or "None" - sections.append(f"| {constraint.name} | {constraint.description} | {penalty} |") - sections.append("") - - # Available tools - if self.available_tools: - sections.append("## Available Tools\n") - sections.append("Your agent can use these tools:\n") - for tool_name in self.available_tools: - sections.append(f"- `{tool_name}`") - sections.append("\nSee the [Tool Reference](../api_reference/tools.md) for details.\n") - - # Perception - if self.perception_info: - sections.append("## What Your Agent Can See\n") - for key, value in self.perception_info.items(): - sections.append(f"- **{key}**: {value}") - sections.append("") - - # Resources and hazards - if self.resource_types: - sections.append("## Resources\n") - sections.append("| Type | Name | Description |") - sections.append("|------|------|-------------|") - for rt in self.resource_types: - sections.append(f"| {rt['type']} | {rt['name']} | {rt.get('description', '')} |") - sections.append("") - - if self.hazard_types: - sections.append("## Hazards\n") - sections.append("| Type | Name | Damage | Description |") - sections.append("|------|------|--------|-------------|") - for ht in self.hazard_types: - sections.append( - f"| {ht['type']} | {ht['name']} | {ht.get('damage', '?')} | " - f"{ht.get('description', '')} |" - ) - sections.append("") - - # Metrics - if self.metrics: - sections.append("## Success Metrics\n") - sections.append("Your agent will be evaluated on:\n") - sections.append("| Metric | Description | Goal |") - sections.append("|--------|-------------|------|") - for metric in self.metrics: - goal_str = metric.optimize - if metric.optimize == "target" and metric.target_value is not None: - goal_str = f"Target: {metric.target_value}" - unit_str = f" ({metric.unit})" if metric.unit else "" - sections.append(f"| {metric.name}{unit_str} | {metric.description} | {goal_str} |") - sections.append("") - - # Hints - if self.hints: - 
sections.append("## Hints\n") - sections.append( - "
\nClick to reveal hints (try without them first!)\n" - ) - for i, hint in enumerate(self.hints, 1): - sections.append(f"{i}. {hint}") - sections.append("\n
\n") - - return "\n".join(sections) - - def to_dict(self) -> dict: - """Convert to dictionary for serialization.""" - return { - "name": self.name, - "id": self.id, - "tier": self.tier, - "version": self.version, - "description": self.description, - "backstory": self.backstory, - "goals": [ - { - "name": g.name, - "description": g.description, - "success_condition": g.success_condition, - "priority": g.priority, - "optional": g.optional, - } - for g in self.goals - ], - "constraints": [ - { - "name": c.name, - "description": c.description, - "penalty": c.penalty, - } - for c in self.constraints - ], - "available_tools": self.available_tools, - "metrics": [ - { - "name": m.name, - "description": m.description, - "unit": m.unit, - "optimize": m.optimize, - "target_value": m.target_value, - } - for m in self.metrics - ], - "success_threshold": self.success_threshold, - "perception_info": self.perception_info, - "hints": self.hints, - "learning_objectives": self.learning_objectives, - "resource_types": self.resource_types, - "hazard_types": self.hazard_types, - } - - @classmethod - def from_dict(cls, data: dict) -> "ScenarioDefinition": - """Create from dictionary.""" - goals = [ - Goal( - name=g["name"], - description=g["description"], - success_condition=g["success_condition"], - priority=g.get("priority", 1), - optional=g.get("optional", False), - ) - for g in data.get("goals", []) - ] - - constraints = [ - Constraint( - name=c["name"], - description=c["description"], - penalty=c.get("penalty"), - ) - for c in data.get("constraints", []) - ] - - metrics = [ - Metric( - name=m["name"], - description=m["description"], - unit=m.get("unit"), - optimize=m.get("optimize", "maximize"), - target_value=m.get("target_value"), - ) - for m in data.get("metrics", []) - ] - - return cls( - name=data["name"], - id=data["id"], - tier=data.get("tier", 1), - version=data.get("version", "1.0.0"), - description=data["description"], - backstory=data.get("backstory"), - goals=goals, - 
constraints=constraints, - available_tools=data.get("available_tools", []), - metrics=metrics, - success_threshold=data.get("success_threshold", {}), - perception_info=data.get("perception_info", {}), - hints=data.get("hints", []), - learning_objectives=data.get("learning_objectives", []), - resource_types=data.get("resource_types", []), - hazard_types=data.get("hazard_types", []), - ) diff --git a/python/backends/__init__.py b/python/backends/__init__.py deleted file mode 100644 index 27e5368..0000000 --- a/python/backends/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -LLM Backend Adapters for Agent Arena -""" - -from .base import BackendConfig, BaseBackend -from .llama_cpp_backend import LlamaCppBackend -from .vllm_backend import VLLMBackend, VLLMBackendConfig - -__all__ = ["BaseBackend", "BackendConfig", "LlamaCppBackend", "VLLMBackend", "VLLMBackendConfig"] diff --git a/python/backends/base.py b/python/backends/base.py deleted file mode 100644 index d5dd947..0000000 --- a/python/backends/base.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Base backend interface for LLM backends. -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any - - -@dataclass -class BackendConfig: - """Configuration for LLM backend.""" - - model_path: str - temperature: float = 0.7 - max_tokens: int = 512 - top_p: float = 0.9 - top_k: int = 40 - n_gpu_layers: int = 0 # Number of layers to offload to GPU (0 = CPU only, -1 = all) - - -@dataclass -class GenerationResult: - """Result from LLM generation.""" - - text: str - tokens_used: int - finish_reason: str # "stop", "length", "error" - metadata: dict[str, Any] - - -class BaseBackend(ABC): - """ - Abstract base class for LLM backends. - - All backend implementations must inherit from this class. - """ - - def __init__(self, config: BackendConfig): - """ - Initialize the backend. 
- - Args: - config: Backend configuration - """ - self.config = config - - @abstractmethod - def generate( - self, - prompt: str, - temperature: float | None = None, - max_tokens: int | None = None, - system_prompt: str | None = None, - ) -> GenerationResult: - """ - Generate text from prompt. - - Args: - prompt: Input prompt - temperature: Override temperature (optional) - max_tokens: Override max tokens (optional) - system_prompt: Optional system message for chat formatting - - Returns: - GenerationResult with generated text and metadata - """ - pass - - @abstractmethod - def generate_with_tools( - self, - prompt: str, - tools: list[dict[str, Any]], - temperature: float | None = None, - ) -> GenerationResult: - """ - Generate with function/tool calling support. - - Args: - prompt: Input prompt - tools: List of available tool schemas - temperature: Override temperature (optional) - - Returns: - GenerationResult with tool call or text - """ - pass - - @abstractmethod - def is_available(self) -> bool: - """ - Check if backend is available and ready. - - Returns: - True if backend is loaded and ready - """ - pass - - @abstractmethod - def unload(self) -> None: - """Unload the model and free resources.""" - pass diff --git a/python/backends/llama_cpp_backend.py b/python/backends/llama_cpp_backend.py deleted file mode 100644 index ce9d6ca..0000000 --- a/python/backends/llama_cpp_backend.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -llama.cpp backend adapter. 
-""" - -import json -import logging -from typing import TYPE_CHECKING, Any, cast - -from .base import BackendConfig, BaseBackend, GenerationResult - -if TYPE_CHECKING: - from llama_cpp import Llama - -logger = logging.getLogger(__name__) - - -class LlamaCppBackend(BaseBackend): - """Backend adapter for llama.cpp.""" - - llm: "Llama | None" - - def __init__(self, config: BackendConfig): - """Initialize llama.cpp backend.""" - super().__init__(config) - self.llm = None - self._load_model() - - def _load_model(self) -> None: - """Load the llama.cpp model.""" - try: - from llama_cpp import Llama - - logger.info(f"Loading model from {self.config.model_path}") - - # Use GPU layers from config - n_gpu_layers = getattr(self.config, "n_gpu_layers", 0) - - if n_gpu_layers > 0: - logger.info(f"Offloading {n_gpu_layers} layers to GPU") - elif n_gpu_layers == -1: - logger.info("Offloading all layers to GPU") - else: - logger.info("Using CPU only (no GPU offload)") - - self.llm = Llama( - model_path=self.config.model_path, - n_ctx=4096, # Context window - n_threads=8, # CPU threads - n_gpu_layers=n_gpu_layers, # GPU layers (0 = CPU only, -1 = all) - verbose=False, # Reduce output noise - ) - - logger.info("Model loaded successfully") - - except ImportError: - logger.error( - "llama-cpp-python not installed. Install with: pip install llama-cpp-python" - ) - raise - except Exception as e: - logger.error(f"Failed to load model: {e}") - raise - - def generate( - self, - prompt: str, - temperature: float | None = None, - max_tokens: int | None = None, - system_prompt: str | None = None, - ) -> GenerationResult: - """Generate text from prompt using chat completion API. 
- - Args: - prompt: The user message / prompt text - temperature: Sampling temperature override - max_tokens: Max tokens override - system_prompt: Optional system message (prepended to chat) - """ - if not self.llm: - raise RuntimeError("Model not loaded") - - temp = temperature if temperature is not None else self.config.temperature - max_tok = max_tokens if max_tokens is not None else self.config.max_tokens - - try: - # Build chat messages so llama.cpp applies the correct chat template - messages: list[dict[str, str]] = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - - response = self.llm.create_chat_completion( - messages=messages, - temperature=temp, - max_tokens=max_tok, - top_p=self.config.top_p, - top_k=self.config.top_k, - ) - - # Cast response to dict since we're not streaming - resp = cast(dict[str, Any], response) - text = resp["choices"][0]["message"]["content"] or "" - tokens_used = resp["usage"]["total_tokens"] - finish_reason = str(resp["choices"][0].get("finish_reason", "stop")) - - return GenerationResult( - text=text, - tokens_used=tokens_used, - finish_reason=finish_reason, - metadata={"model": self.config.model_path}, - ) - - except Exception as e: - logger.error(f"Generation error: {e}") - return GenerationResult( - text="", - tokens_used=0, - finish_reason="error", - metadata={"error": str(e)}, - ) - - def generate_with_tools( - self, - prompt: str, - tools: list[dict[str, Any]], - temperature: float | None = None, - ) -> GenerationResult: - """Generate with function calling support.""" - if not self.llm: - raise RuntimeError("Model not loaded") - - # The prompt already contains the system instructions and user data. - # Just pass it through to generate() — the caller (agent.py) builds - # the full prompt with system + decision template. 
- result = self.generate(prompt, temperature) - - # Try to parse JSON from result - try: - text = result.text.strip() - if text.startswith("```json"): - text = text[7:] - if text.endswith("```"): - text = text[:-3] - - parsed = json.loads(text.strip()) - result.metadata["parsed_tool_call"] = parsed - - except json.JSONDecodeError: - logger.debug("Backend JSON parse failed (expected for CoT format)") - result.metadata["parse_error"] = True - - return result - - def is_available(self) -> bool: - """Check if backend is ready.""" - return self.llm is not None - - def unload(self) -> None: - """Unload the model.""" - if self.llm: - del self.llm - self.llm = None - logger.info("Model unloaded") diff --git a/python/backends/vllm_backend.py b/python/backends/vllm_backend.py deleted file mode 100644 index 39c480f..0000000 --- a/python/backends/vllm_backend.py +++ /dev/null @@ -1,350 +0,0 @@ -""" -vLLM backend adapter using OpenAI-compatible API. - -vLLM is a high-throughput inference engine that provides an OpenAI-compatible -REST API. This backend connects to a vLLM server instance. -""" - -import json -import logging -from typing import Any - -from openai import OpenAI - -from .base import BackendConfig, BaseBackend, GenerationResult - -logger = logging.getLogger(__name__) - - -class VLLMBackendConfig(BackendConfig): - """Extended configuration for vLLM backend.""" - - def __init__( - self, - model_path: str, - api_base: str = "http://localhost:8000/v1", - api_key: str = "EMPTY", - temperature: float = 0.7, - max_tokens: int = 512, - top_p: float = 0.9, - top_k: int = 40, - ): - """ - Initialize vLLM backend config. 
- - Args: - model_path: Model identifier (e.g., "meta-llama/Llama-2-7b-chat-hf") - api_base: Base URL for vLLM server - api_key: API key (vLLM uses "EMPTY" by default) - temperature: Sampling temperature - max_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter - """ - super().__init__( - model_path=model_path, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - ) - self.api_base = api_base - self.api_key = api_key - - -class VLLMBackend(BaseBackend): - """ - Backend adapter for vLLM inference server. - - This backend connects to a running vLLM server using the OpenAI-compatible API. - The vLLM server must be started separately before using this backend. - - Example: - Start vLLM server: - ```bash - python -m vllm.entrypoints.openai.api_server \\ - --model meta-llama/Llama-2-7b-chat-hf \\ - --port 8000 - ``` - - Then use this backend: - ```python - config = VLLMBackendConfig( - model_path="meta-llama/Llama-2-7b-chat-hf", - api_base="http://localhost:8000/v1" - ) - backend = VLLMBackend(config) - result = backend.generate("Hello, world!") - ``` - """ - - def __init__(self, config: VLLMBackendConfig): - """ - Initialize vLLM backend. - - Args: - config: vLLM backend configuration - """ - super().__init__(config) - self.config: VLLMBackendConfig = config - self.client: OpenAI | None = None - self._connect() - - def _connect(self) -> None: - """Connect to vLLM server.""" - try: - logger.info(f"Connecting to vLLM server at {self.config.api_base}") - - self.client = OpenAI( - api_key=self.config.api_key, - base_url=self.config.api_base, - ) - - # Test connection with a simple request - try: - models = self.client.models.list() - logger.info(f"Connected to vLLM. 
Available models: {[m.id for m in models.data]}") - except Exception as e: - logger.warning(f"Could not list models (server may not be ready): {e}") - - except Exception as e: - logger.error(f"Failed to connect to vLLM server: {e}") - raise - - def generate( - self, - prompt: str, - temperature: float | None = None, - max_tokens: int | None = None, - system_prompt: str | None = None, - ) -> GenerationResult: - """ - Generate text from prompt using vLLM. - - Args: - prompt: Input prompt - temperature: Override temperature (optional) - max_tokens: Override max tokens (optional) - system_prompt: Optional system message (not used for completion API) - - Returns: - GenerationResult with generated text and metadata - """ - if not self.client: - raise RuntimeError("vLLM client not connected") - - temp = temperature if temperature is not None else self.config.temperature - max_tok = max_tokens if max_tokens is not None else self.config.max_tokens - - try: - response = self.client.completions.create( - model=self.config.model_path, - prompt=prompt, - temperature=temp, - max_tokens=max_tok, - top_p=self.config.top_p, - extra_body={"top_k": self.config.top_k}, - ) - - text = response.choices[0].text - tokens_used = response.usage.total_tokens if response.usage else 0 - - return GenerationResult( - text=text, - tokens_used=tokens_used, - finish_reason=response.choices[0].finish_reason or "stop", - metadata={ - "model": self.config.model_path, - "api_base": self.config.api_base, - }, - ) - - except Exception as e: - logger.error(f"Generation error: {e}") - return GenerationResult( - text="", - tokens_used=0, - finish_reason="error", - metadata={"error": str(e)}, - ) - - def generate_with_tools( - self, - prompt: str, - tools: list[dict[str, Any]], - temperature: float | None = None, - ) -> GenerationResult: - """ - Generate with function calling support. - - vLLM supports OpenAI-style function calling for compatible models. 
- - Args: - prompt: Input prompt - tools: List of available tool schemas - temperature: Override temperature (optional) - - Returns: - GenerationResult with tool call or text - """ - if not self.client: - raise RuntimeError("vLLM client not connected") - - temp = temperature if temperature is not None else self.config.temperature - - try: - # Convert tool schemas to OpenAI format - openai_tools = [] - for tool in tools: - openai_tools.append( - { - "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": tool.get("parameters", {}), - }, - } - ) - - # Use chat completions API for function calling - response = self.client.chat.completions.create( # type: ignore[call-overload] - model=self.config.model_path, - messages=[{"role": "user", "content": prompt}], - tools=openai_tools, - tool_choice="auto", - temperature=temp, - max_tokens=self.config.max_tokens, - ) - - choice = response.choices[0] - tokens_used = response.usage.total_tokens if response.usage else 0 - - # Check if model returned a tool call - if choice.message.tool_calls: - tool_call = choice.message.tool_calls[0] - text = choice.message.content or "" - - return GenerationResult( - text=text, - tokens_used=tokens_used, - finish_reason=choice.finish_reason or "stop", - metadata={ - "model": self.config.model_path, - "tool_call": { - "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), - }, - }, - ) - else: - # No tool call, return regular text - text = choice.message.content or "" - return GenerationResult( - text=text, - tokens_used=tokens_used, - finish_reason=choice.finish_reason or "stop", - metadata={"model": self.config.model_path}, - ) - - except Exception as e: - logger.error(f"Tool generation error: {e}") - - # Fallback to prompt-based tool calling - logger.info("Falling back to prompt-based tool calling") - return self._generate_with_tools_fallback(prompt, tools, temp) - - def 
_generate_with_tools_fallback( - self, - prompt: str, - tools: list[dict[str, Any]], - temperature: float, - ) -> GenerationResult: - """ - Fallback method for tool calling using prompt engineering. - - Used when the model doesn't support native function calling. - - Args: - prompt: Input prompt - tools: List of available tool schemas - temperature: Sampling temperature - - Returns: - GenerationResult with tool call attempt - """ - # Build a prompt that includes tool schemas - tool_descriptions = [] - for tool in tools: - tool_desc = f"- {tool['name']}: {tool['description']}" - if "parameters" in tool: - tool_desc += f"\n Parameters: {json.dumps(tool['parameters'])}" - tool_descriptions.append(tool_desc) - - tools_text = "\n".join(tool_descriptions) - - enhanced_prompt = f"""{prompt} - -Available tools: -{tools_text} - -Respond with a JSON object in the format: -{{"tool": "tool_name", "params": {{}}, "reasoning": "why this tool"}} - -Or if no tool is needed: -{{"tool": "none", "reasoning": "explanation"}} -""" - - result = self.generate(enhanced_prompt, temperature) - - # Try to parse JSON from result - try: - text = result.text.strip() - # Remove markdown code blocks if present - if text.startswith("```json"): - text = text[7:] - elif text.startswith("```"): - text = text[3:] - if text.endswith("```"): - text = text[:-3] - - parsed = json.loads(text.strip()) - result.metadata["parsed_tool_call"] = parsed - - except json.JSONDecodeError: - # This is expected for Chain-of-Thought format (THINKING: ... ACTION: ...) - # The downstream AgentDecision.from_llm_response() handles CoT parsing - logger.debug("Backend JSON parse failed (expected for CoT format)") - result.metadata["parse_error"] = True - - return result - - def is_available(self) -> bool: - """ - Check if vLLM server is available and ready. 
- - Returns: - True if server is connected and responsive - """ - if not self.client: - return False - - try: - # Try to list models as a health check - self.client.models.list() - return True - except Exception as e: - logger.debug(f"vLLM availability check failed: {e}") - return False - - def unload(self) -> None: - """ - Disconnect from vLLM server. - - Note: This only closes the client connection. The vLLM server - continues running and must be stopped separately if needed. - """ - if self.client: - self.client.close() - self.client = None - logger.info("Disconnected from vLLM server") diff --git a/python/ipc/server.py b/python/ipc/server.py deleted file mode 100644 index 76b0a5c..0000000 --- a/python/ipc/server.py +++ /dev/null @@ -1,754 +0,0 @@ -""" -IPC Server - FastAPI server for handling Godot <-> Python communication. - -This server receives perception data from Godot, processes agent decisions, -and returns actions to execute in the simulation. -""" - -import logging -import time -from contextlib import asynccontextmanager -from typing import Any - -from fastapi import FastAPI, HTTPException - -from agent_runtime.behavior import AgentBehavior -from agent_runtime.runtime import AgentRuntime -from agent_runtime.schemas import ToolSchema -from agent_runtime.tool_dispatcher import ToolDispatcher -from tools import ( - register_inventory_tools, - register_movement_tools, - register_navigation_tools, - register_world_query_tools, -) - -from .converters import decision_to_action, perception_to_observation -from .messages import ( - ActionMessage, - TickRequest, - TickResponse, - ToolExecutionRequest, - ToolExecutionResponse, -) - -logger = logging.getLogger(__name__) - - -class IPCServer: - """ - IPC Server for handling communication between Godot and Python. - - Runs a FastAPI server that receives tick requests and returns agent actions. 
- """ - - def __init__( - self, - runtime: AgentRuntime, - behaviors: dict | None = None, - default_behavior: "AgentBehavior | None" = None, - host: str = "127.0.0.1", - port: int = 5000, - ): - """ - Initialize the IPC server. - - Args: - runtime: AgentRuntime instance to process agent decisions - behaviors: Dictionary of agent_id -> AgentBehavior instances - default_behavior: Default behavior to use for unregistered agents - host: Host address to bind to - port: Port to listen on - """ - self.runtime = runtime - self.behaviors = behaviors if behaviors is not None else {} - self.default_behavior = default_behavior - self.host = host - self.port = port - self.app: FastAPI | None = None - self.tool_dispatcher = ToolDispatcher() - self._register_all_tools() - self.metrics = { - "total_ticks": 0, - "total_agents_processed": 0, - "avg_tick_time_ms": 0.0, - "total_tools_executed": 0, - "total_observations_processed": 0, - } - # Track last tick per agent to detect episode resets - self._last_tick_per_agent: dict[str, int] = {} - - def _register_all_tools(self) -> None: - """Register all available tools with the dispatcher.""" - register_movement_tools(self.tool_dispatcher) - register_inventory_tools(self.tool_dispatcher) - register_world_query_tools(self.tool_dispatcher) - register_navigation_tools(self.tool_dispatcher) - logger.info(f"Registered {len(self.tool_dispatcher.tools)} tools") - - def _make_mock_decision(self, observation: dict[str, Any]) -> dict[str, Any]: - """ - Generate a mock decision based on observation using rule-based logic. - - This is a simple decision-making system for testing the observation-decision - loop without requiring LLM inference. - - Priority: - 1. Avoid nearby hazards (distance < 3.0) - 2. Move to nearest resource (distance < 5.0) - 3. 
Default to idle - - Args: - observation: Observation dictionary with nearby_resources and nearby_hazards - - Returns: - Decision dictionary with tool, params, and reasoning - """ - nearby_resources = observation.get("nearby_resources", []) - nearby_hazards = observation.get("nearby_hazards", []) - - # Priority 1: Avoid hazards that are too close - for hazard in nearby_hazards: - distance = hazard.get("distance", float("inf")) - if distance < 3.0: - hazard_pos = hazard.get("position", [0, 0, 0]) - hazard_type = hazard.get("type", "unknown") - - # Calculate a safe position away from the hazard using move_to - agent_pos = observation.get("position", [0, 0, 0]) - - # Vector from hazard to agent - dx = agent_pos[0] - hazard_pos[0] - dz = agent_pos[2] - hazard_pos[2] - - # Normalize and scale to move 5 units away from hazard - length = (dx**2 + dz**2) ** 0.5 - if length > 0: - dx = (dx / length) * 5.0 - dz = (dz / length) * 5.0 - else: - # If on top of hazard, move in arbitrary direction - dx, dz = 5.0, 0.0 - - safe_position = [ - hazard_pos[0] + dx, - agent_pos[1], # Keep same Y - hazard_pos[2] + dz, - ] - - return { - "tool": "move_to", - "params": {"target_position": safe_position, "speed": 2.0}, - "reasoning": f"Avoiding nearby {hazard_type} hazard at distance {distance:.1f}", - } - - # Priority 2: Move to nearest resource if within range - if nearby_resources: - # Find closest resource - closest_resource = min(nearby_resources, key=lambda r: r.get("distance", float("inf"))) - distance = closest_resource.get("distance", float("inf")) - - if distance < 5.0: - resource_pos = closest_resource.get("position", [0, 0, 0]) - resource_type = closest_resource.get("type", "unknown") - resource_name = closest_resource.get("name", "resource") - return { - "tool": "move_to", - "params": {"target_position": resource_pos, "speed": 1.5}, - "reasoning": f"Moving to collect {resource_type} ({resource_name}) at distance {distance:.1f}", - } - - # Default: Idle (no immediate actions 
needed) - return { - "tool": "idle", - "params": {}, - "reasoning": "No immediate actions needed - exploring environment", - } - - def create_app(self) -> FastAPI: - """Create and configure the FastAPI application.""" - - @asynccontextmanager - async def lifespan(app: FastAPI): - """Lifespan context manager for startup/shutdown.""" - logger.info("Starting IPC server...") - self.runtime.start() - yield - logger.info("Shutting down IPC server...") - self.runtime.stop() - - app = FastAPI( - title="Agent Arena IPC Server", - description="Communication bridge between Godot simulation and Python agents", - version="0.1.0", - lifespan=lifespan, - ) - - @app.get("/") - async def root(): - """Health check endpoint.""" - return { - "status": "running", - "agents": len(self.runtime.agents), - "metrics": self.metrics, - } - - @app.get("/health") - async def health(): - """Health check endpoint.""" - return {"status": "ok", "agents": len(self.runtime.agents)} - - @app.post("/tick") - async def process_tick(request_data: dict[str, Any]) -> dict[str, Any]: - """ - Process a simulation tick. - - Receives perception data for all agents, processes decisions, - and returns actions to execute. 
- - Args: - request_data: Tick request containing agent perceptions - - Returns: - Tick response containing agent actions - """ - start_time = time.time() - - try: - # Parse request - tick_request = TickRequest.from_dict(request_data) - tick = tick_request.tick - - logger.info( - f"[/tick] Processing tick {tick} with {len(tick_request.perceptions)} agents" - ) - - # Get tool schemas for agents - tool_schemas = [] - for name, schema in self.tool_dispatcher.schemas.items(): - tool_schemas.append( - ToolSchema( - name=schema.name, - description=schema.description, - parameters=schema.parameters, - ) - ) - - # Process each agent using registered behaviors - action_messages = [] - for perception in tick_request.perceptions: - agent_id = perception.agent_id - - # Convert perception to Observation - observation = perception_to_observation(perception) - - # Get behavior for this agent (or use default) - behavior = self.behaviors.get(agent_id) or self.default_behavior - - # If no specific behavior, check for default behavior - if behavior is None and "_default" in self.behaviors: - behavior = self.behaviors["_default"] - logger.debug(f"Using default behavior for agent {agent_id}") - - if behavior: - # Call behavior.decide() with Observation and tools - try: - # Detect episode reset (tick went backwards or restarted) - last_tick = self._last_tick_per_agent.get(agent_id, -1) - if tick <= last_tick and tick <= 1: - # New episode detected - clear memory - logger.info( - f"Episode reset detected for {agent_id} (tick {tick} <= {last_tick})" - ) - behavior.on_episode_start() - self._last_tick_per_agent[agent_id] = tick - - # Set trace context before decide() for reasoning trace logging - behavior._set_trace_context(agent_id, tick) - - # Update world map with current observation (spatial memory) - behavior._update_world_map(observation) - - decision = behavior.decide(observation, tool_schemas) - - # End trace after decide() to persist trace to disk - behavior._end_trace() - - # 
Convert decision to ActionMessage - action_msg = decision_to_action(decision, agent_id, tick) - action_messages.append(action_msg) - - logger.info( - f"[/tick] Agent {agent_id} decided: {decision.tool} - {decision.reasoning}" - ) - except Exception as e: - logger.error( - f"Error in behavior.decide() for agent {agent_id}: {e}", - exc_info=True, - ) - # End trace even on error - behavior._end_trace() - # Fallback to idle - from agent_runtime.schemas import AgentDecision - - decision = AgentDecision.idle(reasoning=f"Error: {str(e)}") - action_msg = decision_to_action(decision, agent_id, tick) - action_messages.append(action_msg) - else: - # No behavior registered, use mock decision - logger.warning( - f"No behavior registered for agent {agent_id}, using mock decision" - ) - mock_obs = { - "agent_id": agent_id, - "position": perception.position, - "nearby_resources": perception.custom_data.get("nearby_resources", []), - "nearby_hazards": perception.custom_data.get("nearby_hazards", []), - } - decision_dict = self._make_mock_decision(mock_obs) - action_msg = ActionMessage( - agent_id=agent_id, - tick=tick, - tool=decision_dict["tool"], - params=decision_dict["params"], - reasoning=decision_dict["reasoning"], - ) - action_messages.append(action_msg) - - # Calculate metrics - elapsed_ms = (time.time() - start_time) * 1000 - self.metrics["total_ticks"] += 1 - self.metrics["total_agents_processed"] += len(tick_request.perceptions) - self.metrics["avg_tick_time_ms"] = ( - self.metrics["avg_tick_time_ms"] * 0.9 + elapsed_ms * 0.1 - ) - - # Build response - response = TickResponse( - tick=tick, - actions=action_messages, - metrics={ - "tick_time_ms": elapsed_ms, - "agents_processed": len(tick_request.perceptions), - "actions_generated": len(action_messages), - }, - ) - - logger.info( - f"[/tick] Tick {tick} processed in {elapsed_ms:.2f}ms, " - f"{len(action_messages)} actions generated" - ) - - return dict(response.to_dict()) - - except Exception as e: - logger.error(f"Error 
processing tick: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/agents/register") - async def register_agent(agent_data: dict[str, Any]) -> dict[str, str]: - """ - Register a new agent with the runtime. - - Args: - agent_data: Agent configuration data - - Returns: - Success message with agent ID - """ - try: - agent_id = agent_data.get("agent_id") - if not agent_id: - raise HTTPException(status_code=400, detail="agent_id is required") - - # Note: Agent instantiation would happen here - # For now, we'll just acknowledge the registration - logger.info(f"Agent registration request received for {agent_id}") - - return {"status": "success", "agent_id": agent_id} - - except Exception as e: - logger.error(f"Error registering agent: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/tools/execute") - async def execute_tool(request_data: dict[str, Any]) -> dict[str, Any]: - """ - Execute a tool requested from Godot. - - Args: - request_data: Tool execution request - - Returns: - Tool execution response with result or error - """ - try: - # Parse request - tool_request = ToolExecutionRequest.from_dict(request_data) - - logger.debug( - f"Executing tool '{tool_request.tool_name}' " - f"for agent '{tool_request.agent_id}' at tick {tool_request.tick}" - ) - - # Execute the tool through dispatcher - result = self.tool_dispatcher.execute_tool( - tool_request.tool_name, tool_request.params - ) - - # Update metrics - self.metrics["total_tools_executed"] += 1 - - # Build response - response = ToolExecutionResponse( - success=result.get("success", False), - result=result.get("result"), - error=result.get("error", ""), - ) - - logger.debug( - f"Tool '{tool_request.tool_name}' executed: success={response.success}" - ) - - return dict(response.to_dict()) - - except Exception as e: - logger.error(f"Error executing tool: {e}", exc_info=True) - return dict(ToolExecutionResponse(success=False, error=str(e)).to_dict()) - - 
@app.get("/tools/list") - async def list_tools() -> dict[str, Any]: - """Get list of available tools and their schemas.""" - logger.info("[/tools/list] Tools requested") - schemas = {} - for name, schema in self.tool_dispatcher.schemas.items(): - schemas[name] = { - "name": schema.name, - "description": schema.description, - "parameters": schema.parameters, - "returns": schema.returns, - } - logger.info(f"[/tools/list] Returning {len(schemas)} tools") - return {"tools": schemas, "count": len(schemas)} - - @app.get("/metrics") - async def get_metrics(): - """Get server performance metrics.""" - return self.metrics - - @app.get("/inspector/requests") - async def get_inspector_requests( - agent_id: str | None = None, - tick: int | None = None, - tick_start: int | None = None, - tick_end: int | None = None, - ): - """ - Get captured reasoning traces from the TraceStore. - - Query parameters: - agent_id: Filter by specific agent (optional) - tick: Get data for a specific tick (optional) - tick_start: Minimum tick number (inclusive, optional) - tick_end: Maximum tick number (inclusive, optional) - - Returns: - List of reasoning traces with full decision-making data - """ - from agent_runtime.reasoning_trace import get_global_trace_store - - inspector = get_global_trace_store() - - # Single capture query - if agent_id and tick is not None: - capture = inspector.get_capture(agent_id, tick) - if not capture: - raise HTTPException( - status_code=404, detail=f"No capture found for agent {agent_id} tick {tick}" - ) - return {"captures": [capture.to_dict()], "count": 1} - - # Agent-specific query - if agent_id: - captures = inspector.get_captures_for_agent(agent_id, tick_start, tick_end) - return {"captures": [c.to_dict() for c in captures], "count": len(captures)} - - # All captures query - captures = inspector.get_all_captures(tick_start, tick_end) - return {"captures": [c.to_dict() for c in captures], "count": len(captures)} - - @app.delete("/inspector/requests") - async def 
clear_inspector_requests(): - """Clear all captured reasoning traces from the TraceStore.""" - from agent_runtime.reasoning_trace import get_global_trace_store - - inspector = get_global_trace_store() - inspector.clear() - return {"status": "cleared", "message": "All inspector captures have been cleared"} - - @app.get("/inspector/config") - async def get_inspector_config(): - """Get current TraceStore configuration.""" - from agent_runtime.reasoning_trace import get_global_trace_store - - inspector = get_global_trace_store() - return { - "enabled": inspector.enabled, - "max_entries": inspector.max_entries, - "log_to_file": inspector.log_to_file, - "log_dir": str(inspector.log_dir), - "current_captures": len(inspector.traces), - "episode_id": inspector.episode_id, - } - - @app.get("/traces/episode/{episode_id}") - async def get_episode_traces(episode_id: str): - """ - Get all reasoning traces for a specific episode. - - This loads traces from the JSONL file on disk. - - Args: - episode_id: Episode identifier (e.g., "ep_20260131_120000") - - Returns: - List of all traces from that episode - """ - from agent_runtime.reasoning_trace import get_global_trace_store - - trace_store = get_global_trace_store() - traces = trace_store.get_episode_traces(episode_id) - - if not traces: - raise HTTPException( - status_code=404, detail=f"No traces found for episode {episode_id}" - ) - - return { - "episode_id": episode_id, - "traces": [t.to_dict() for t in traces], - "count": len(traces), - } - - @app.post("/traces/episode/start") - async def start_episode(episode_id: str | None = None): - """ - Start a new episode for trace collection. 
- - Args: - episode_id: Optional custom episode ID (auto-generated if not provided) - - Returns: - The episode ID that was started - """ - import datetime - - from agent_runtime.reasoning_trace import get_global_trace_store - - trace_store = get_global_trace_store() - - if not episode_id: - timestamp = datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S") - episode_id = f"ep_{timestamp}" - - trace_store.start_episode(episode_id) - - return {"episode_id": episode_id, "message": f"Started episode {episode_id}"} - - @app.post("/traces/episode/end") - async def end_episode(): - """ - End the current episode and close the trace file. - - Returns: - The episode ID that was ended - """ - from agent_runtime.reasoning_trace import get_global_trace_store - - trace_store = get_global_trace_store() - episode_id = trace_store.episode_id - trace_store.end_episode() - - return {"episode_id": episode_id, "message": f"Ended episode {episode_id}"} - - @app.post("/observe") - async def process_observation(observation: dict[str, Any]) -> dict[str, Any]: - """ - Process a single observation and return a decision. 
- - This endpoint supports both registered behaviors and mock logic: - - If a behavior is registered for the agent_id, it will be used - - Otherwise, falls back to rule-based mock logic - - Args: - observation: Observation data containing: - - agent_id: Agent identifier - - position: [x, y, z] position - - nearby_resources: List of visible resources - - nearby_hazards: List of nearby hazards - - Returns: - Decision dictionary with tool, params, and reasoning - """ - try: - agent_id = observation.get("agent_id", "unknown") - - logger.info(f"[/observe] Processing observation for agent '{agent_id}'") - logger.debug(f"Position: {observation.get('position')}") - logger.debug(f"Resources: {len(observation.get('nearby_resources', []))}") - logger.debug(f"Hazards: {len(observation.get('nearby_hazards', []))}") - - # Check if we have a registered behavior for this agent (or use default) - behavior = self.behaviors.get(agent_id) or self.default_behavior - - if behavior: - # Log which behavior type is being used - behavior_type = "registered" if agent_id in self.behaviors else "default" - logger.info(f"[/observe] Using {behavior_type} behavior for agent '{agent_id}'") - # Convert observation dict to Observation object - from ipc.messages import PerceptionMessage - - from .converters import perception_to_observation - - # Create PerceptionMessage from observation dict - perception = PerceptionMessage( - agent_id=agent_id, - tick=observation.get("tick", 0), - position=observation.get("position", [0, 0, 0]), - rotation=observation.get("rotation", [0, 0, 0]), - velocity=observation.get("velocity", [0, 0, 0]), - custom_data={ - "nearby_resources": observation.get("nearby_resources", []), - "nearby_hazards": observation.get("nearby_hazards", []), - "inventory": observation.get("inventory", []), - "health": observation.get("health", 100.0), - "energy": observation.get("energy", 100.0), - }, - ) - - obs = perception_to_observation(perception) - - # Get tool schemas - tool_schemas = 
[] - for name, schema in self.tool_dispatcher.schemas.items(): - tool_schemas.append( - ToolSchema( - name=schema.name, - description=schema.description, - parameters=schema.parameters, - ) - ) - - try: - # Set trace context before decide() for reasoning trace logging - tick = observation.get("tick", 0) - behavior._set_trace_context(agent_id, tick) - - # Call behavior.decide() - agent_decision = behavior.decide(obs, tool_schemas) - - # End trace after decide() to persist trace to disk - behavior._end_trace() - - decision = { - "tool": agent_decision.tool, - "params": agent_decision.params, - "reasoning": agent_decision.reasoning or "Agent decision", - } - except Exception as e: - logger.error(f"Error in behavior.decide(): {e}", exc_info=True) - # End trace even on error - behavior._end_trace() - decision = { - "tool": "idle", - "params": {}, - "reasoning": f"Error: {str(e)}", - } - else: - # Generate mock decision using rule-based logic - decision = self._make_mock_decision(observation) - - # Update metrics - self.metrics["total_observations_processed"] += 1 - - logger.info( - f"Agent {agent_id} decision: {decision['tool']} - {decision['reasoning']}" - ) - - return { - "agent_id": agent_id, - "tool": decision["tool"], - "params": decision["params"], - "reasoning": decision["reasoning"], - } - - except Exception as e: - logger.error(f"Error processing observation: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) - - self.app = app - return app - - def run(self): - """Run the IPC server (blocking).""" - import uvicorn - - if not self.app: - self.create_app() - - logger.info(f"Starting IPC server on {self.host}:{self.port}") - uvicorn.run(self.app, host=self.host, port=self.port, log_level="info") - - async def run_async(self): - """Run the IPC server asynchronously.""" - import uvicorn - - if not self.app: - self.create_app() - - config = uvicorn.Config(self.app, host=self.host, port=self.port, log_level="info") - server = 
uvicorn.Server(config) - await server.serve() - - -def create_server( - runtime: AgentRuntime | None = None, - behaviors: dict | None = None, - default_behavior: AgentBehavior | None = None, - host: str = "127.0.0.1", - port: int = 5000, -) -> IPCServer: - """ - Factory function to create an IPC server. - - Args: - runtime: AgentRuntime instance, or None to create a default one - behaviors: Dictionary of agent_id -> AgentBehavior instances - default_behavior: Default behavior to use for agents not in behaviors dict - host: Host address to bind to - port: Port to listen on - - Returns: - Configured IPCServer instance - - Example: - from agent_runtime import create_local_llm_behavior - - # Create a default LLM behavior for all agents - default = create_local_llm_behavior( - model_path="models/mistral-7b.gguf", - system_prompt="You are a foraging agent." - ) - - server = create_server(default_behavior=default) - """ - if runtime is None: - runtime = AgentRuntime(max_workers=4) - - return IPCServer( - runtime=runtime, - behaviors=behaviors, - default_behavior=default_behavior, - host=host, - port=port, - ) diff --git a/python/sdk/agent_arena_sdk/__init__.py b/python/sdk/agent_arena_sdk/__init__.py index 7f5e806..b689ab6 100644 --- a/python/sdk/agent_arena_sdk/__init__.py +++ b/python/sdk/agent_arena_sdk/__init__.py @@ -23,6 +23,7 @@ def decide(obs: Observation) -> Decision: For complete examples, see the starter templates in the AgentArena repository. 
class FrameworkAdapter(ABC):
    """
    Base class for framework-specific adapters.

    Subclasses implement ``decide()`` with their framework's LLM client.
    The base class provides shared utilities so each adapter doesn't
    duplicate observation formatting or tool definitions.

    Example::

        class MyAdapter(FrameworkAdapter):
            def decide(self, obs: Observation) -> Decision:
                context = self.format_observation(obs)
                tools = self.get_action_tools()
                # ... call your LLM with context and tools ...
                return Decision(tool="move_to", params={...})

        arena = AgentArena()
        arena.run(MyAdapter())
    """

    @abstractmethod
    def decide(self, obs: Observation) -> Decision:
        """Make a decision given an observation. Subclasses must implement."""
        ...

    def format_observation(self, obs: Observation) -> str:
        """
        Format an observation into human-readable text for an LLM prompt.

        Includes: position, health, energy, nearby resources/hazards/stations,
        inventory, exploration status, objective progress, and last tool result.

        Override to customize how observations are presented to your LLM.
        """
        lines: list[str] = []

        # Header
        lines.append(
            f"Tick: {obs.tick} | Position: ({obs.position[0]:.1f}, "
            f"{obs.position[1]:.1f}, {obs.position[2]:.1f}) | "
            f"Health: {obs.health:.0f} | Energy: {obs.energy:.0f}"
        )

        # Perception (default 50.0 is used only when the field is absent on
        # older observation payloads)
        perception = getattr(obs, "perception_radius", 50.0)
        exploration_pct = 0.0
        if obs.exploration:
            exploration_pct = obs.exploration.exploration_percentage
        lines.append(f"Perception: {perception:.0f} units | Explored: {exploration_pct:.1f}%")
        lines.append("")

        # Resources (top 5)
        if obs.nearby_resources:
            summaries = [
                f"{r.name} ({r.type}) dist={r.distance:.1f} pos={list(r.position)}"
                for r in obs.nearby_resources[:5]
            ]
            lines.append(f"Resources: {'; '.join(summaries)}")
        else:
            lines.append("Resources: None")

        # Hazards (top 5)
        if obs.nearby_hazards:
            summaries = [
                f"{h.name} ({h.type}) dist={h.distance:.1f} pos={list(h.position)}"
                for h in obs.nearby_hazards[:5]
            ]
            lines.append(f"Hazards: {'; '.join(summaries)}")
        else:
            lines.append("Hazards: None")

        # Stations (top 5)
        if obs.nearby_stations:
            summaries = [
                f"{s.name} ({s.type}) dist={s.distance:.1f} pos={list(s.position)}"
                for s in obs.nearby_stations[:5]
            ]
            lines.append(f"Stations: {'; '.join(summaries)}")
        else:
            lines.append("Stations: None")

        # Inventory (handles both dict format from custom and ItemInfo list;
        # the custom dict, when present, takes precedence)
        raw_inventory = obs.custom.get("inventory", {}) if obs.custom else {}
        if raw_inventory:
            inv_str = ", ".join(f"{k}: {v}" for k, v in raw_inventory.items())
            lines.append(f"Inventory: {inv_str}")
        elif obs.inventory:
            inv_str = ", ".join(f"{item.name} x{item.quantity}" for item in obs.inventory)
            lines.append(f"Inventory: {inv_str}")
        else:
            lines.append("Inventory: Empty")

        lines.append("")

        # Exploration targets (top 4)
        if obs.exploration and obs.exploration.explore_targets:
            targets = [
                f"{t.direction} pos={list(t.position)} ({t.distance:.1f}u away)"
                for t in obs.exploration.explore_targets[:4]
            ]
            lines.append(f"Exploration targets: {'; '.join(targets)}")
        else:
            lines.append("Exploration targets: None")

        # Exploration hint when no resources visible
        if not obs.nearby_resources:
            if obs.exploration and obs.exploration.explore_targets:
                best = obs.exploration.explore_targets[0]
                pos = list(best.position)
                lines.append(
                    f"No resources visible! Use explore or move_to an "
                    f"exploration target to find them. Nearest: {pos}"
                )
            else:
                lines.append("No resources visible! Move to an unexplored area to find them.")

        # Objective
        if obs.objective:
            lines.append("")
            lines.append(f"Objective: {obs.objective.description}")
            if obs.current_progress:
                progress_parts = []
                for metric, value in obs.current_progress.items():
                    target = ""
                    if obs.objective.success_metrics and metric in obs.objective.success_metrics:
                        target = f"/{obs.objective.success_metrics[metric].target:.0f}"
                    progress_parts.append(f"{metric}: {value:.0f}{target}")
                lines.append(f"Progress: {', '.join(progress_parts)}")

        # Last tool result
        if obs.last_tool_result:
            tr = obs.last_tool_result
            status = "OK" if tr.success else f"FAILED: {tr.error}"
            lines.append(f"Last action: {tr.tool} -> {status}")

        return "\n".join(lines)

    def get_action_tools(self) -> list[ToolSchema]:
        """
        Return the canonical set of action tools.

        These are terminal tools — calling one ends the agent's turn.
        Descriptions include "This ends your turn." so LLMs can
        distinguish action tools from future query tools.

        Override to add scenario-specific tools or modify descriptions.
        """
        return [
            ToolSchema(
                name="move_to",
                description=(
                    "Navigate to a target position. The game handles "
                    "pathfinding and obstacle avoidance. This ends your turn."
                ),
                parameters={
                    "type": "object",
                    "properties": {
                        "target_position": {
                            "type": "array",
                            "items": {"type": "number"},
                            "description": "Target position as [x, y, z]",
                        }
                    },
                    "required": ["target_position"],
                },
            ),
            ToolSchema(
                name="collect",
                description=(
                    "Collect a nearby resource by name. Must be within "
                    "collection range. This ends your turn."
                ),
                parameters={
                    "type": "object",
                    "properties": {
                        "target_name": {
                            "type": "string",
                            "description": "Name of the resource to collect",
                        }
                    },
                    "required": ["target_name"],
                },
            ),
            ToolSchema(
                name="craft_item",
                description=(
                    "Craft an item at a nearby crafting station. Must be "
                    "within range of the correct station type. Recipes: "
                    "torch (1 wood + 1 stone at workbench), "
                    "meal (2 berry at workbench), "
                    "shelter (3 wood + 2 stone at anvil). "
                    "This ends your turn."
                ),
                parameters={
                    "type": "object",
                    "properties": {
                        "recipe": {
                            "type": "string",
                            "description": "Recipe name (e.g., 'torch', 'shelter', 'meal')",
                        }
                    },
                    "required": ["recipe"],
                },
            ),
            ToolSchema(
                name="explore",
                description=(
                    "Move toward the nearest unexplored area to discover "
                    "new resources. Use this when no resources are visible. "
                    "This ends your turn."
                ),
                parameters={"type": "object", "properties": {}},
            ),
            ToolSchema(
                name="idle",
                description="Do nothing this tick. This ends your turn.",
                parameters={"type": "object", "properties": {}},
            ),
        ]

    def fallback_decision(self, obs: Observation) -> Decision:
        """
        Make a sensible fallback decision from observation data.

        Used when the LLM fails to produce a valid tool call. Priority:

        1. Flee from nearby hazards (within 3.0 units)
        2. Move toward the closest resource
        3. Move toward the nearest exploration target
        4. Move in +X direction as a last resort

        Override to customize fallback behavior.
        """
        # Priority 1: Flee from nearby hazards
        if obs.nearby_hazards:
            closest = min(obs.nearby_hazards, key=lambda h: h.distance)
            if closest.distance < 3.0:
                hx, hy, hz = closest.position
                px, py, pz = obs.position
                dx, dz = px - hx, pz - hz
                dist = (dx**2 + dz**2) ** 0.5
                if dist < 1e-6:
                    # BUGFIX: when the agent stands exactly on the hazard the
                    # flee vector is zero, and the previous clamp-to-0.1
                    # produced a target equal to the current position (no
                    # movement at all). Pick +X arbitrarily so we still flee.
                    dx, dz, dist = 1.0, 0.0, 1.0
                flee_x = px + (dx / dist) * 5.0
                flee_z = pz + (dz / dist) * 5.0
                return Decision(
                    tool="move_to",
                    params={"target_position": [flee_x, py, flee_z]},
                    reasoning=f"Fleeing hazard {closest.name} at dist {closest.distance:.1f}",
                )

        # Priority 2: Move toward closest resource
        if obs.nearby_resources:
            closest = min(obs.nearby_resources, key=lambda r: r.distance)
            return Decision(
                tool="move_to",
                params={"target_position": list(closest.position)},
                reasoning=f"Moving toward {closest.name} at dist {closest.distance:.1f}",
            )

        # Priority 3: Move toward nearest exploration target
        if obs.exploration and obs.exploration.explore_targets:
            best = obs.exploration.explore_targets[0]
            return Decision(
                tool="move_to",
                params={"target_position": list(best.position)},
                reasoning=f"Exploring {best.direction}",
            )

        # Priority 4: Move in +X direction
        px, py, pz = obs.position
        return Decision(
            tool="move_to",
            params={"target_position": [px + 10.0, py, pz]},
            reasoning="No resources or exploration data, moving to explore",
        )
a/python/sdk/agent_arena_sdk/arena.py +++ b/python/sdk/agent_arena_sdk/arena.py @@ -4,8 +4,10 @@ This provides a simple API for learners to connect their agents to the game. """ +from __future__ import annotations + import logging -from typing import Callable +from typing import Any, Callable from .schemas import Decision, Observation from .server import MinimalIPCServer @@ -13,6 +15,27 @@ logger = logging.getLogger(__name__) +def _resolve_callback( + agent: Callable[[Observation], Decision] | Any, +) -> Callable[[Observation], Decision]: + """Resolve *agent* to a plain callback for the IPC server. + + Accepts: + - A ``FrameworkAdapter`` (or any object with a ``decide`` method) + - A bare callable ``(Observation) -> Decision`` + + Raises ``TypeError`` if *agent* is neither. + """ + if hasattr(agent, "decide") and callable(agent.decide): + return agent.decide # type: ignore[return-value] + if callable(agent): + return agent # type: ignore[return-value] + raise TypeError( + "agent must be a callable(Observation -> Decision) or an object " + "with a decide(Observation) -> Decision method" + ) + + class AgentArena: """ Connection manager for Agent Arena game. 
@@ -22,7 +45,8 @@ class AgentArena: - Calling your decide function each tick - Managing the connection to Godot - Example: + Example — bare callback:: + from agent_arena_sdk import AgentArena, Observation, Decision def my_decide(obs: Observation) -> Decision: @@ -36,6 +60,14 @@ def my_decide(obs: Observation) -> Decision: arena = AgentArena(host="127.0.0.1", port=5000) arena.run(my_decide) # Blocks until stopped + + Example — framework adapter:: + + from agent_arena_sdk import AgentArena + from my_adapter import MyAdapter + + arena = AgentArena() + arena.run(MyAdapter(model="claude-sonnet-4-20250514")) """ def __init__( @@ -63,7 +95,7 @@ def __init__( f"{' (debug enabled)' if enable_debug else ''}" ) - def run(self, decide_callback: Callable[[Observation], Decision]) -> None: + def run(self, agent: Callable[[Observation], Decision] | Any) -> None: """ Run the agent server (blocking). @@ -71,7 +103,9 @@ def run(self, decide_callback: Callable[[Observation], Decision]) -> None: Blocks until the server is stopped (Ctrl+C). Args: - decide_callback: Function that takes Observation and returns Decision + agent: A callable ``(Observation) -> Decision``, or an object + with a ``decide(Observation) -> Decision`` method (e.g. a + :class:`~agent_arena_sdk.adapters.FrameworkAdapter`). Example: def decide(obs: Observation) -> Decision: @@ -80,6 +114,8 @@ def decide(obs: Observation) -> Decision: arena = AgentArena() arena.run(decide) # Blocks here """ + decide_callback = _resolve_callback(agent) + logger.info("Starting agent server...") logger.info("Waiting for connection from Godot...") @@ -97,14 +133,15 @@ def decide(obs: Observation) -> Decision: finally: logger.info("Agent server shut down") - async def run_async(self, decide_callback: Callable[[Observation], Decision]) -> None: + async def run_async(self, agent: Callable[[Observation], Decision] | Any) -> None: """ Run the agent server (async). This is an async version of run() that can be awaited. 
Args: - decide_callback: Function that takes Observation and returns Decision + agent: A callable ``(Observation) -> Decision``, or an object + with a ``decide(Observation) -> Decision`` method. Example: async def main(): @@ -116,6 +153,8 @@ def decide(obs: Observation) -> Decision: asyncio.run(main()) """ + decide_callback = _resolve_callback(agent) + logger.info("Starting agent server (async)...") logger.info("Waiting for connection from Godot...") diff --git a/starters/README.md b/starters/README.md index c7647c7..29d76bd 100644 --- a/starters/README.md +++ b/starters/README.md @@ -14,7 +14,8 @@ Unlike frameworks that hide complexity behind base classes, starters give you wo |---------|-------------|----------| | [beginner/](beginner/) | Simple if/else logic, no memory | Learning the basics | | [intermediate/](intermediate/) | Memory, planning, state tracking | Building real skills | -| [llm/](llm/) | LLM-powered reasoning | Advanced techniques | +| [llm/](llm/) | Local LLM reasoning (llama.cpp) | Advanced techniques | +| [claude/](claude/) | Anthropic Claude with native tool_use | Learning framework integration | ## Quick Start diff --git a/starters/claude/README.md b/starters/claude/README.md new file mode 100644 index 0000000..208c267 --- /dev/null +++ b/starters/claude/README.md @@ -0,0 +1,190 @@ +# Claude Starter — Learn Anthropic Tool Use by Building a Game Agent + +This starter teaches you **Anthropic's Claude tool_use API** by building an AI agent that plays Agent Arena scenarios. + +## What You'll Learn + +- **Tool definitions** — how to describe actions as JSON Schema tools +- **Tool use responses** — how Claude calls tools with typed parameters +- **System prompts** — giving Claude personality, strategy, and constraints +- **Context injection** — formatting game state into effective prompts +- **Error handling** — graceful fallbacks when the LLM doesn't cooperate + +## Prerequisites + +1. 
An **Anthropic API key** — get one at [console.anthropic.com](https://console.anthropic.com) +2. Python 3.11+ +3. Agent Arena game (Godot) running + +## Quick Start + +```bash +# 1. Set your API key +export ANTHROPIC_API_KEY=sk-ant-... + +# 2. Install dependencies +pip install -r requirements.txt + +# 3. Start the agent +python run.py + +# 4. In Godot: open scenes/foraging.tscn → F5 → SPACE +``` + +Your agent will start making decisions using Claude! + +## Files + +| File | What it does | +|------|-------------| +| `agent.py` | `ClaudeAdapter` — formats observations, calls Claude, extracts tool calls | +| `run.py` | Entry point — parses args, creates adapter, starts server | +| `requirements.txt` | Dependencies (agent-arena-sdk, anthropic) | + +## How It Works + +Each game tick: + +``` +Godot sends Observation (what the agent sees) + ↓ +ClaudeAdapter.format_observation() → text context + ↓ +Claude reads context + tool definitions + ↓ +Claude calls a tool (e.g., move_to with target_position) + ↓ +ClaudeAdapter extracts tool call → Decision + ↓ +Decision sent back to Godot +``` + +### The Key Concept: Tool Use + +Instead of asking Claude to output JSON (fragile, needs parsing), we define **tools**: + +```python +# This is what gets sent to Claude as a tool definition: +{ + "name": "move_to", + "description": "Navigate to a target position. This ends your turn.", + "input_schema": { + "type": "object", + "properties": { + "target_position": { + "type": "array", + "items": {"type": "number"}, + "description": "Target position as [x, y, z]" + } + }, + "required": ["target_position"] + } +} +``` + +Claude responds with a structured tool call — no string parsing needed: + +```python +# Claude's response contains a tool_use block: +block.type == "tool_use" +block.name == "move_to" +block.input == {"target_position": [10.0, 0.0, 5.0]} +``` + +## Customization + +### Change the System Prompt + +Edit `SYSTEM_PROMPT` at the top of `agent.py`. 
Try: +- Adding personality ("You are a cautious agent that avoids all risk") +- Changing strategy ("Always explore before collecting") +- Adding domain knowledge ("Fire hazards deal 10 damage per tick") + +### Change the Model + +```bash +python run.py --model claude-haiku-4-5-20251001 # Fastest, cheapest +python run.py --model claude-sonnet-4-20250514 # Balanced (default) +python run.py --model claude-opus-4-20250514 # Most capable +``` + +### Add Memory + +The base adapter is stateless (each tick is independent). Add memory: + +```python +class ClaudeAdapterWithMemory(ClaudeAdapter): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.history = [] # Remember past observations + + def decide(self, obs): + self.history.append(obs) + # Include recent history in the prompt + return super().decide(obs) +``` + +### Override Observation Formatting + +```python +class MyAdapter(ClaudeAdapter): + def format_observation(self, obs): + # Your custom formatting + text = super().format_observation(obs) + text += "\n\nRemembered locations: ..." + return text +``` + +## Cost Estimation + +Each tick costs approximately: +- **Haiku**: ~0.1 cent (500 input + 100 output tokens) +- **Sonnet**: ~0.5 cent +- **Opus**: ~2.5 cents + +A typical foraging run (100 ticks) costs ~$0.50 with Sonnet. + +## Debugging + +### Enable Debug Viewer + +```bash +python run.py --debug +# Open http://127.0.0.1:5000/debug in your browser +``` + +### View Traces + +The adapter records each decision in `self.last_trace` with: +- System prompt sent +- Observation context sent +- Tokens used +- Parse method (tool_use, fallback, error) +- Final decision + +### Common Issues + +**"Claude did not call a tool"** — Claude sometimes returns text without calling a tool. The adapter falls back to observation-based logic. Try making the system prompt more directive. + +**High latency** — Each tick requires an API round-trip. Use Haiku for faster responses, or add caching for repeated observations.
+ +**"ANTHROPIC_API_KEY not set"** — Export your API key: `export ANTHROPIC_API_KEY=sk-ant-...` + +## Comparison with LLM Starter + +| Feature | LLM Starter | Claude Starter | +|---------|------------|---------------| +| LLM location | Local (llama.cpp) | Cloud (Anthropic API) | +| Output format | JSON text parsing | Native tool_use | +| GPU required | Yes | No | +| Cost | Free (after model download) | Per-token API cost | +| Latency | Low (local) | Medium (network) | +| Model quality | Varies (local models) | High (Claude) | +| Setup | Download model (~4GB) | Set API key | + +## Next Steps + +- Modify `SYSTEM_PROMPT` to improve decision quality +- Add memory to remember past observations +- Try different models and compare scores +- Read the [Anthropic docs](https://docs.anthropic.com/en/docs/build-with-claude/tool-use) to learn more about tool use diff --git a/starters/claude/agent.py b/starters/claude/agent.py new file mode 100644 index 0000000..8815018 --- /dev/null +++ b/starters/claude/agent.py @@ -0,0 +1,228 @@ +""" +Claude-Powered Agent — Using Anthropic's Claude with Native Tool Use + +This agent uses Claude's tool_use feature for structured decision making. +Instead of generating JSON text and parsing it (like the LLM starter), +Claude directly calls tools with typed parameters — no parsing errors. + +How it works: +1. Each tick, the game sends an Observation (what the agent sees) +2. We format that into text context for Claude +3. Claude reads the context and calls an action tool (move_to, collect, etc.) +4. We extract the tool call and return it as a Decision + +This is YOUR code — modify the system prompt, change the model, +add memory, or customize the observation formatting! 
import logging
import os

from anthropic import Anthropic

from agent_arena_sdk import Decision, Observation
from agent_arena_sdk.adapters import FrameworkAdapter

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# System prompt — edit this to change your agent's personality and strategy!
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """\
You are an AI agent in a 3D simulation world (50m x 50m).
Your goal is to collect resources, avoid hazards, and complete objectives.

IMPORTANT CONSTRAINTS:
- You have LIMITED VISIBILITY (~10 unit radius). You can only see nearby things.
- World boundaries are roughly -25 to +25 on both X and Z axes.
- Each tool call is ONE action per tick. Choose wisely.

STRATEGY:
- If a hazard is within 3 units, move away immediately (survival first).
- If resources are nearby, move toward the closest one to collect it.
- If you have crafting materials and are near a station, craft items.
- If nothing is visible, explore to discover new areas.
- Read the objective and prioritize actions that make progress.

Crafting recipes (must be near the correct station):
- torch: 1 wood + 1 stone (workbench)
- meal: 2 berry (workbench)
- shelter: 3 wood + 2 stone (anvil)

Use the provided tools to take actions. Each tool call ends your turn.\
"""


class ClaudeAdapter(FrameworkAdapter):
    """
    Anthropic Claude adapter using native tool_use.

    Per tick this adapter:
    1. Formats the observation into text context
    2. Sends it to Claude with action tool definitions
    3. Claude calls a tool → we return that as the Decision

    The adapter is stateless across ticks except for ``last_trace``,
    which records the most recent decision for the debug viewer.

    Customise:
    - ``SYSTEM_PROMPT`` (module-level) for personality / strategy
    - ``model`` for different Claude models
    - ``max_tokens`` for response length budget
    """

    def __init__(
        self,
        model: str = "claude-sonnet-4-20250514",
        max_tokens: int = 1024,
        api_key: str | None = None,
    ) -> None:
        """
        Initialise the Claude adapter.

        Args:
            model: Anthropic model ID. Good options:
                - claude-sonnet-4-20250514 (fast, cheap — good default)
                - claude-opus-4-20250514 (most capable, slower)
                - claude-haiku-4-5-20251001 (fastest, cheapest)
            max_tokens: Maximum tokens for Claude's response.
            api_key: Anthropic API key. If None, reads ANTHROPIC_API_KEY
                from environment.
        """
        self.model = model
        self.max_tokens = max_tokens
        # NOTE(review): the Anthropic client presumably also reads
        # ANTHROPIC_API_KEY itself when api_key is None, making the explicit
        # environ lookup redundant but harmless — confirm before removing.
        self.client = Anthropic(
            api_key=api_key or os.environ.get("ANTHROPIC_API_KEY")
        )
        self.system_prompt = SYSTEM_PROMPT

        # Chain-of-thought trace for the debug viewer.
        # The SDK's debug system reads this via adapter.last_trace.
        self.last_trace: dict | None = None

    def decide(self, obs: Observation) -> Decision:
        """
        Make a decision using Claude's tool_use.

        1. Format observation → text context
        2. Call Claude with tools
        3. Extract tool_use block → Decision
        4. Fall back to observation-based logic on any failure
        """
        # --- Build prompt -------------------------------------------------
        obs_text = self.format_observation(obs)

        # Convert our ToolSchema objects to Anthropic's format:
        # {"name": ..., "description": ..., "input_schema": {...}}
        tools = [t.to_anthropic_format() for t in self.get_action_tools()]

        # Trace dict for the debug viewer; llm_raw_output stays None on the
        # successful tool_use path and is only filled when Claude returns
        # plain text without a tool call.
        trace: dict = {
            "system_prompt": self.system_prompt,
            "user_prompt": obs_text,
            "llm_raw_output": None,
            "tokens_used": 0,
            "finish_reason": None,
            "parse_method": None,
            "decision": None,
        }

        # --- Call Claude --------------------------------------------------
        # NOTE(review): no timeout or retry is configured here — a hung API
        # request stalls the game tick. Confirm this is acceptable.
        try:
            response = self.client.messages.create(
                model=self.model,
                max_tokens=self.max_tokens,
                system=self.system_prompt,
                messages=[{"role": "user", "content": obs_text}],
                tools=tools,
            )

            trace["tokens_used"] = (
                response.usage.input_tokens + response.usage.output_tokens
            )
            trace["finish_reason"] = response.stop_reason

            # --- Extract tool call ----------------------------------------
            # Claude returns content blocks. We look for a tool_use block.
            # Only the FIRST tool_use block is honoured — we return as soon
            # as one is found, since each tool call ends the turn.
            for block in response.content:
                if block.type == "tool_use":
                    tool_name = block.name
                    tool_input = block.input

                    # The "explore" tool is synthetic — Claude calls it, but
                    # the game only understands move_to. We translate here.
                    if tool_name == "explore":
                        decision = self._resolve_explore(obs)
                    else:
                        decision = Decision(
                            tool=tool_name,
                            params=tool_input,
                            reasoning=self._extract_reasoning(response),
                        )

                    trace["parse_method"] = "tool_use"
                    trace["decision"] = {
                        "tool": decision.tool,
                        "params": decision.params,
                        "reasoning": decision.reasoning,
                    }
                    self.last_trace = trace
                    return decision

            # Claude returned text but no tool call
            trace["llm_raw_output"] = self._extract_all_text(response)
            trace["parse_method"] = "fallback_no_tool_use"
            logger.warning("Claude did not call a tool, using fallback")

        except Exception as e:
            # Any SDK/network failure lands here; we log and fall through
            # to the observation-based fallback rather than crash the tick.
            logger.error(f"Anthropic API error: {e}")
            trace["parse_method"] = "error"

        # --- Fallback -----------------------------------------------------
        # fallback_decision() is inherited from FrameworkAdapter.
        decision = self.fallback_decision(obs)
        trace["decision"] = {
            "tool": decision.tool,
            "params": decision.params,
            "reasoning": decision.reasoning,
        }
        self.last_trace = trace
        return decision

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _resolve_explore(self, obs: Observation) -> Decision:
        """Translate the synthetic 'explore' tool into a concrete move_to."""
        if obs.exploration and obs.exploration.explore_targets:
            # Targets are used in list order; the first is treated as best.
            target = obs.exploration.explore_targets[0]
            return Decision(
                tool="move_to",
                params={"target_position": list(target.position)},
                reasoning=f"Exploring {target.direction}",
            )
        # No exploration data — move in +X as a heuristic
        px, py, pz = obs.position
        return Decision(
            tool="move_to",
            params={"target_position": [px + 10.0, py, pz]},
            reasoning="Exploring (no targets available)",
        )

    @staticmethod
    def _extract_reasoning(response) -> str:
        """Pull text blocks from the response to use as reasoning."""
        texts = [
            block.text
            for block in response.content
            if block.type == "text"
        ]
        return " ".join(texts) if texts else "Claude decision"

    @staticmethod
    def _extract_all_text(response) -> str:
        """Concatenate all text content blocks."""
        return " ".join(
            block.text for block in response.content if block.type == "text"
        )
logger.info("=" * 60) + + adapter = ClaudeAdapter(model=args.model) + + arena = AgentArena( + host="127.0.0.1", + port=args.port, + enable_debug=args.debug, + ) + + try: + arena.run(adapter) + except KeyboardInterrupt: + logger.info("\nAgent stopped by user") + + +if __name__ == "__main__": + main() diff --git a/starters/llm/llm_client.py b/starters/llm/llm_client.py index 166548a..2367c0d 100644 --- a/starters/llm/llm_client.py +++ b/starters/llm/llm_client.py @@ -1,33 +1,29 @@ """ LLM Client - Interface to Local Language Models -This module provides a simple interface to use local LLMs with the model manager. +This module provides a simple interface to use local LLMs via llama-cpp-python. You can see exactly how it works and modify it! Supports: - llama.cpp backend (GGUF models) -- vLLM backend (for high-performance inference) - Automatic tool calling + +Requirements: + pip install llama-cpp-python """ import json import logging -from pathlib import Path -import sys - -# Add parent directories to path for imports -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "python")) +from typing import Any, cast -from backends.base import BackendConfig -from backends.llama_cpp_backend import LlamaCppBackend -from agent_runtime.schemas import ToolSchema +from agent_arena_sdk import ToolSchema logger = logging.getLogger(__name__) class LLMClient: """ - Simple LLM client using local models. + Simple LLM client using local models via llama-cpp-python. This client: - Uses models managed by the model manager @@ -52,35 +48,53 @@ def __init__( temperature: float = 0.7, max_tokens: int = 512, n_gpu_layers: int = -1, # -1 = all layers on GPU + top_p: float = 0.9, + top_k: int = 40, ): """ Initialize LLM client. 
Args: - model_path: Path to model file (relative to project root) + model_path: Path to GGUF model file temperature: Sampling temperature (0-1, higher = more creative) max_tokens: Maximum tokens to generate n_gpu_layers: Number of layers on GPU (-1 = all, 0 = CPU only) + top_p: Top-p sampling parameter + top_k: Top-k sampling parameter """ + self.model_path = model_path self.temperature = temperature self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.llm = None - # Create backend config - config = BackendConfig( - model_path=model_path, - temperature=temperature, - max_tokens=max_tokens, - n_gpu_layers=n_gpu_layers, - ) - - # Initialize backend - logger.info(f"Loading model from {model_path}") - self.backend = LlamaCppBackend(config) + try: + from llama_cpp import Llama + + logger.info(f"Loading model from {model_path}") + + if n_gpu_layers == -1: + logger.info("Offloading all layers to GPU") + elif n_gpu_layers > 0: + logger.info(f"Offloading {n_gpu_layers} layers to GPU") + else: + logger.info("Using CPU only (no GPU offload)") + + self.llm = Llama( + model_path=model_path, + n_ctx=4096, + n_threads=8, + n_gpu_layers=n_gpu_layers, + verbose=False, + ) - if not self.backend.is_available(): - raise RuntimeError(f"Failed to load model from {model_path}") + logger.info("Model loaded successfully") - logger.info("Model loaded successfully") + except ImportError: + raise RuntimeError( + "llama-cpp-python not installed. 
Install with: pip install llama-cpp-python" + ) def generate( self, @@ -105,23 +119,38 @@ def generate( - tokens_used: Number of tokens generated - finish_reason: Why generation stopped """ + if not self.llm: + raise RuntimeError("Model not loaded") + try: - result = self.backend.generate( - prompt=prompt, + messages: list[dict[str, str]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = self.llm.create_chat_completion( + messages=messages, temperature=temperature or self.temperature, - system_prompt=system_prompt, + max_tokens=self.max_tokens, + top_p=self.top_p, + top_k=self.top_k, ) + resp = cast(dict[str, Any], response) + text = resp["choices"][0]["message"]["content"] or "" + tokens_used = resp["usage"]["total_tokens"] + finish_reason = str(resp["choices"][0].get("finish_reason", "stop")) + # Parse tool calls if present tool_call = None - if tools and result.text: - tool_call = self._parse_tool_call(result.text) + if tools and text: + tool_call = self._parse_tool_call(text) return { - "text": result.text, + "text": text, "tool_call": tool_call, - "tokens_used": result.tokens_used, - "finish_reason": result.finish_reason, + "tokens_used": tokens_used, + "finish_reason": finish_reason, } except Exception as e: @@ -166,9 +195,11 @@ def _parse_tool_call(self, text: str) -> dict | None: def is_available(self) -> bool: """Check if the LLM backend is ready.""" - return self.backend.is_available() + return self.llm is not None def unload(self) -> None: """Unload the model and free resources.""" - self.backend.unload() - logger.info("Model unloaded") + if self.llm: + del self.llm + self.llm = None + logger.info("Model unloaded") diff --git a/tests/test_adapter_base.py b/tests/test_adapter_base.py new file mode 100644 index 0000000..3e9d997 --- /dev/null +++ b/tests/test_adapter_base.py @@ -0,0 +1,448 @@ +"""Tests for FrameworkAdapter base class.""" + +import 
pytest + +from agent_arena_sdk import Decision, Observation, ToolSchema +from agent_arena_sdk.adapters.base import FrameworkAdapter +from agent_arena_sdk.schemas.observation import ( + ExplorationInfo, + ExploreTarget, + HazardInfo, + ItemInfo, + ResourceInfo, + StationInfo, + ToolResult, +) +from agent_arena_sdk.schemas.objective import MetricDefinition, Objective + + +class ConcreteAdapter(FrameworkAdapter): + """Minimal concrete adapter for testing base class methods.""" + + def decide(self, obs: Observation) -> Decision: + return Decision.idle("test") + + +def _make_obs(**kwargs) -> Observation: + """Helper to create test observations with defaults.""" + defaults = {"agent_id": "test", "tick": 1, "position": (0.0, 0.0, 0.0)} + defaults.update(kwargs) + return Observation(**defaults) + + +# --------------------------------------------------------------------------- +# FrameworkAdapter ABC +# --------------------------------------------------------------------------- + + +class TestFrameworkAdapterABC: + def test_cannot_instantiate_directly(self): + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + FrameworkAdapter() # type: ignore[abstract] + + def test_requires_decide_implementation(self): + class IncompleteAdapter(FrameworkAdapter): + pass + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteAdapter() # type: ignore[abstract] + + def test_concrete_adapter_works(self): + adapter = ConcreteAdapter() + obs = _make_obs() + decision = adapter.decide(obs) + assert decision.tool == "idle" + + +# --------------------------------------------------------------------------- +# format_observation +# --------------------------------------------------------------------------- + + +class TestFormatObservation: + def test_minimal_observation(self): + adapter = ConcreteAdapter() + text = adapter.format_observation(_make_obs()) + assert "Tick: 1" in text + assert "Position:" in text + assert "Health: 100" in text + 
assert "Energy: 100" in text + assert "Resources: None" in text + assert "Hazards: None" in text + assert "Stations: None" in text + assert "Inventory: Empty" in text + + def test_with_resources(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_resources=[ + ResourceInfo( + name="berry_1", type="berry", position=(5.0, 0.0, 3.0), distance=5.8 + ), + ResourceInfo( + name="wood_1", type="wood", position=(8.0, 0.0, 1.0), distance=8.1 + ), + ] + ) + text = adapter.format_observation(obs) + assert "berry_1" in text + assert "berry" in text + assert "5.8" in text + assert "wood_1" in text + + def test_resources_limited_to_5(self): + adapter = ConcreteAdapter() + resources = [ + ResourceInfo( + name=f"r_{i}", type="berry", position=(float(i), 0.0, 0.0), distance=float(i) + ) + for i in range(10) + ] + obs = _make_obs(nearby_resources=resources) + text = adapter.format_observation(obs) + assert "r_4" in text + assert "r_5" not in text + + def test_with_hazards(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_hazards=[ + HazardInfo( + name="fire_1", type="fire", position=(2.0, 0.0, 1.0), distance=2.2 + ) + ] + ) + text = adapter.format_observation(obs) + assert "fire_1" in text + assert "fire" in text + assert "2.2" in text + + def test_with_stations(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_stations=[ + StationInfo( + name="bench_1", + type="workbench", + position=(3.0, 0.0, 4.0), + distance=5.0, + ) + ] + ) + text = adapter.format_observation(obs) + assert "bench_1" in text + assert "workbench" in text + + def test_inventory_dict_format(self): + adapter = ConcreteAdapter() + obs = _make_obs(custom={"inventory": {"wood": 3, "stone": 1}}) + text = adapter.format_observation(obs) + assert "wood" in text + assert "3" in text + assert "stone" in text + + def test_inventory_iteminfo_format(self): + adapter = ConcreteAdapter() + obs = _make_obs( + inventory=[ItemInfo(id="i1", name="torch", quantity=2)] + ) + text = 
adapter.format_observation(obs) + assert "torch" in text + assert "x2" in text + + def test_exploration_targets(self): + adapter = ConcreteAdapter() + obs = _make_obs( + exploration=ExplorationInfo( + exploration_percentage=45.0, + total_cells=100, + seen_cells=45, + frontiers_by_direction={"north": 10.0}, + explore_targets=[ + ExploreTarget( + direction="north", distance=10.0, position=(0.0, 0.0, 10.0) + ) + ], + ) + ) + text = adapter.format_observation(obs) + assert "45.0%" in text + assert "north" in text + + def test_exploration_hint_when_no_resources(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_resources=[], + exploration=ExplorationInfo( + exploration_percentage=30.0, + total_cells=100, + seen_cells=30, + frontiers_by_direction={}, + explore_targets=[ + ExploreTarget( + direction="east", distance=12.0, position=(12.0, 0.0, 0.0) + ) + ], + ), + ) + text = adapter.format_observation(obs) + assert "No resources visible" in text + assert "12.0" in text + + def test_exploration_hint_no_targets(self): + adapter = ConcreteAdapter() + obs = _make_obs(nearby_resources=[]) + text = adapter.format_observation(obs) + assert "No resources visible" in text + assert "unexplored area" in text + + def test_no_hint_when_resources_present(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_resources=[ + ResourceInfo( + name="r1", type="berry", position=(5.0, 0.0, 0.0), distance=5.0 + ) + ] + ) + text = adapter.format_observation(obs) + assert "No resources visible" not in text + + def test_objective_included(self): + adapter = ConcreteAdapter() + obs = _make_obs( + objective=Objective( + description="Collect 10 resources", + success_metrics={"resources_collected": MetricDefinition(target=10.0)}, + ), + current_progress={"resources_collected": 4.0}, + ) + text = adapter.format_observation(obs) + assert "Collect 10 resources" in text + assert "resources_collected" in text + assert "4" in text + + def test_last_tool_result_success(self): + adapter 
= ConcreteAdapter() + obs = _make_obs( + last_tool_result=ToolResult(tool="move_to", success=True) + ) + text = adapter.format_observation(obs) + assert "move_to" in text + assert "OK" in text + + def test_last_tool_result_failure(self): + adapter = ConcreteAdapter() + obs = _make_obs( + last_tool_result=ToolResult( + tool="collect", success=False, error="OUT_OF_RANGE" + ) + ) + text = adapter.format_observation(obs) + assert "collect" in text + assert "FAILED" in text + assert "OUT_OF_RANGE" in text + + +# --------------------------------------------------------------------------- +# get_action_tools +# --------------------------------------------------------------------------- + + +class TestGetActionTools: + def test_returns_tool_schemas(self): + adapter = ConcreteAdapter() + tools = adapter.get_action_tools() + assert isinstance(tools, list) + assert all(isinstance(t, ToolSchema) for t in tools) + + def test_canonical_tool_names(self): + adapter = ConcreteAdapter() + names = {t.name for t in adapter.get_action_tools()} + assert names == {"move_to", "collect", "craft_item", "explore", "idle"} + + def test_descriptions_say_ends_turn(self): + adapter = ConcreteAdapter() + for tool in adapter.get_action_tools(): + assert "ends your turn" in tool.description.lower(), ( + f"Tool {tool.name} description missing 'ends your turn'" + ) + + def test_anthropic_format_conversion(self): + adapter = ConcreteAdapter() + for tool in adapter.get_action_tools(): + fmt = tool.to_anthropic_format() + assert "name" in fmt + assert "description" in fmt + assert "input_schema" in fmt + assert fmt["name"] == tool.name + + def test_openai_format_conversion(self): + adapter = ConcreteAdapter() + for tool in adapter.get_action_tools(): + fmt = tool.to_openai_format() + assert fmt["type"] == "function" + assert fmt["function"]["name"] == tool.name + + def test_move_to_has_target_position(self): + adapter = ConcreteAdapter() + tools = {t.name: t for t in adapter.get_action_tools()} + 
move_to = tools["move_to"] + assert "target_position" in move_to.parameters["properties"] + assert "target_position" in move_to.parameters["required"] + + def test_collect_has_target_name(self): + adapter = ConcreteAdapter() + tools = {t.name: t for t in adapter.get_action_tools()} + collect = tools["collect"] + assert "target_name" in collect.parameters["properties"] + + def test_craft_item_has_recipe(self): + adapter = ConcreteAdapter() + tools = {t.name: t for t in adapter.get_action_tools()} + craft = tools["craft_item"] + assert "recipe" in craft.parameters["properties"] + + +# --------------------------------------------------------------------------- +# fallback_decision +# --------------------------------------------------------------------------- + + +class TestFallbackDecision: + def test_flee_hazard(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_hazards=[ + HazardInfo( + name="fire", type="fire", position=(1.0, 0.0, 0.0), distance=1.0 + ) + ] + ) + decision = adapter.fallback_decision(obs) + assert decision.tool == "move_to" + # Should move away from hazard at x=1 (i.e. 
negative X direction) + assert decision.params["target_position"][0] < 0 + + def test_hazard_far_away_not_fleeing(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_hazards=[ + HazardInfo( + name="fire", type="fire", position=(10.0, 0.0, 0.0), distance=10.0 + ) + ] + ) + decision = adapter.fallback_decision(obs) + # Hazard at 10.0 > 3.0 threshold, should not flee + assert decision.tool == "move_to" + # With no resources, goes to +X direction + assert decision.params["target_position"][0] > 0 + + def test_collect_nearest_resource(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_resources=[ + ResourceInfo( + name="berry", type="berry", position=(5.0, 0.0, 3.0), distance=5.8 + ), + ResourceInfo( + name="wood", type="wood", position=(2.0, 0.0, 1.0), distance=2.2 + ), + ] + ) + decision = adapter.fallback_decision(obs) + assert decision.tool == "move_to" + # Should pick wood (closer at 2.2) + assert decision.params["target_position"] == [2.0, 0.0, 1.0] + + def test_explore_when_nothing_visible(self): + adapter = ConcreteAdapter() + obs = _make_obs( + exploration=ExplorationInfo( + exploration_percentage=20.0, + total_cells=100, + seen_cells=20, + frontiers_by_direction={}, + explore_targets=[ + ExploreTarget( + direction="east", distance=15.0, position=(15.0, 0.0, 0.0) + ) + ], + ) + ) + decision = adapter.fallback_decision(obs) + assert decision.tool == "move_to" + assert decision.params["target_position"] == [15.0, 0.0, 0.0] + + def test_default_move_when_no_data(self): + adapter = ConcreteAdapter() + obs = _make_obs() + decision = adapter.fallback_decision(obs) + assert decision.tool == "move_to" + # Moves +10 in X from position (0, 0, 0) + assert decision.params["target_position"][0] == pytest.approx(10.0) + + def test_hazard_priority_over_resources(self): + adapter = ConcreteAdapter() + obs = _make_obs( + nearby_hazards=[ + HazardInfo( + name="fire", type="fire", position=(1.0, 0.0, 0.0), distance=1.0 + ) + ], + nearby_resources=[ + 
ResourceInfo( + name="berry", type="berry", position=(5.0, 0.0, 0.0), distance=5.0 + ) + ], + ) + decision = adapter.fallback_decision(obs) + # Should flee (hazard at 1.0 < 3.0), not go toward resource + assert decision.params["target_position"][0] < 0 + + +# --------------------------------------------------------------------------- +# AgentArena.run() duck-typing +# --------------------------------------------------------------------------- + + +class TestResolveCallback: + def test_callable_accepted(self): + from agent_arena_sdk.arena import _resolve_callback + + def my_decide(obs: Observation) -> Decision: + return Decision.idle() + + cb = _resolve_callback(my_decide) + assert cb is my_decide + + def test_adapter_accepted(self): + from agent_arena_sdk.arena import _resolve_callback + + adapter = ConcreteAdapter() + cb = _resolve_callback(adapter) + assert cb == adapter.decide + + def test_object_with_decide_accepted(self): + from agent_arena_sdk.arena import _resolve_callback + + class PlainAgent: + def decide(self, obs: Observation) -> Decision: + return Decision.idle() + + agent = PlainAgent() + cb = _resolve_callback(agent) + assert cb == agent.decide + + def test_non_callable_rejected(self): + from agent_arena_sdk.arena import _resolve_callback + + with pytest.raises(TypeError, match="callable"): + _resolve_callback(42) + + def test_string_rejected(self): + from agent_arena_sdk.arena import _resolve_callback + + with pytest.raises(TypeError, match="callable"): + _resolve_callback("not a callback") diff --git a/tests/test_anthropic_adapter.py b/tests/test_anthropic_adapter.py new file mode 100644 index 0000000..cab2450 --- /dev/null +++ b/tests/test_anthropic_adapter.py @@ -0,0 +1,345 @@ +"""Tests for the Anthropic/Claude adapter (starters/claude/agent.py). + +All Anthropic API calls are mocked — no real API key needed. +The ``anthropic`` package is not required to be installed. 
+""" + +import importlib +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from agent_arena_sdk import Decision, Observation +from agent_arena_sdk.schemas.observation import ( + ExplorationInfo, + ExploreTarget, + HazardInfo, + ResourceInfo, +) + +# --------------------------------------------------------------------------- +# Mock the ``anthropic`` package so the starter can be imported without it +# being installed. We create a fake module with a MagicMock ``Anthropic`` +# class. +# --------------------------------------------------------------------------- +_anthropic_mock = types.ModuleType("anthropic") +_anthropic_mock.Anthropic = MagicMock # type: ignore[attr-defined] +sys.modules.setdefault("anthropic", _anthropic_mock) + +# The Claude starter lives outside the installed package. +_CLAUDE_STARTER = str(Path(__file__).resolve().parent.parent / "starters" / "claude") +if _CLAUDE_STARTER not in sys.path: + sys.path.insert(0, _CLAUDE_STARTER) + +# Force (re-)import so it picks up the mock. 
+if "agent" in sys.modules: + importlib.reload(sys.modules["agent"]) +from agent import ClaudeAdapter # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_obs(**kwargs) -> Observation: + defaults = {"agent_id": "test", "tick": 1, "position": (5.0, 0.0, 3.0)} + defaults.update(kwargs) + return Observation(**defaults) + + +def _mock_tool_use_response( + tool_name: str = "move_to", + tool_input: dict | None = None, + text: str = "I should move toward the berry.", +) -> MagicMock: + """Build a mock Anthropic response containing a tool_use block.""" + if tool_input is None: + tool_input = {"target_position": [10.0, 0.0, 5.0]} + + text_block = MagicMock() + text_block.type = "text" + text_block.text = text + + tool_block = MagicMock() + tool_block.type = "tool_use" + tool_block.name = tool_name + tool_block.input = tool_input + + response = MagicMock() + response.content = [text_block, tool_block] + response.stop_reason = "tool_use" + response.usage.input_tokens = 120 + response.usage.output_tokens = 45 + return response + + +def _mock_text_only_response(text: str = "Let me think about this...") -> MagicMock: + """Build a mock Anthropic response with NO tool_use block.""" + text_block = MagicMock() + text_block.type = "text" + text_block.text = text + + response = MagicMock() + response.content = [text_block] + response.stop_reason = "end_turn" + response.usage.input_tokens = 80 + response.usage.output_tokens = 30 + return response + + +def _make_adapter(mock_client: MagicMock | None = None) -> ClaudeAdapter: + """Create a ClaudeAdapter with a mocked Anthropic client.""" + adapter = ClaudeAdapter(api_key="test-key") + if mock_client is not None: + adapter.client = mock_client + return adapter + + +# --------------------------------------------------------------------------- +# Tests +# 
--------------------------------------------------------------------------- + + +class TestClaudeAdapterDecide: + def test_tool_use_move_to(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response( + tool_name="move_to", + tool_input={"target_position": [10.0, 0.0, 5.0]}, + ) + + adapter = _make_adapter(mock_client) + decision = adapter.decide(_make_obs()) + + assert decision.tool == "move_to" + assert decision.params == {"target_position": [10.0, 0.0, 5.0]} + assert isinstance(decision.reasoning, str) + + def test_tool_use_collect(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response( + tool_name="collect", + tool_input={"target_name": "berry_001"}, + text="Collecting the nearby berry.", + ) + + adapter = _make_adapter(mock_client) + decision = adapter.decide(_make_obs()) + + assert decision.tool == "collect" + assert decision.params == {"target_name": "berry_001"} + + def test_tool_use_craft_item(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response( + tool_name="craft_item", + tool_input={"recipe": "torch"}, + ) + + adapter = _make_adapter(mock_client) + decision = adapter.decide(_make_obs()) + + assert decision.tool == "craft_item" + assert decision.params == {"recipe": "torch"} + + def test_tool_use_idle(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response( + tool_name="idle", + tool_input={}, + ) + + adapter = _make_adapter(mock_client) + decision = adapter.decide(_make_obs()) + + assert decision.tool == "idle" + + +class TestExploreTranslation: + def test_explore_with_targets(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response( + tool_name="explore", + tool_input={}, + ) + + adapter = _make_adapter(mock_client) + obs = _make_obs( + exploration=ExplorationInfo( + exploration_percentage=25.0, + total_cells=100, + 
seen_cells=25, + frontiers_by_direction={}, + explore_targets=[ + ExploreTarget( + direction="north", + distance=12.0, + position=(0.0, 0.0, 12.0), + ) + ], + ) + ) + decision = adapter.decide(obs) + + # explore is translated to move_to + assert decision.tool == "move_to" + assert decision.params["target_position"] == [0.0, 0.0, 12.0] + assert "north" in (decision.reasoning or "").lower() + + def test_explore_without_targets(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response( + tool_name="explore", + tool_input={}, + ) + + adapter = _make_adapter(mock_client) + obs = _make_obs(position=(5.0, 0.0, 3.0)) + decision = adapter.decide(obs) + + # Falls back to +10 in X + assert decision.tool == "move_to" + assert decision.params["target_position"][0] == pytest.approx(15.0) + + +class TestFallbacks: + def test_fallback_on_text_only_response(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_text_only_response() + + adapter = _make_adapter(mock_client) + obs = _make_obs( + nearby_resources=[ + ResourceInfo( + name="berry", type="berry", position=(8.0, 0.0, 4.0), distance=3.2 + ) + ] + ) + decision = adapter.decide(obs) + + # Fallback should go to nearest resource + assert decision.tool == "move_to" + assert decision.params["target_position"] == [8.0, 0.0, 4.0] + + def test_fallback_on_api_error(self): + mock_client = MagicMock() + mock_client.messages.create.side_effect = Exception("rate limit exceeded") + + adapter = _make_adapter(mock_client) + obs = _make_obs() + decision = adapter.decide(obs) + + # Should not raise, should return a valid decision + assert decision.tool in ("move_to", "idle") + assert isinstance(decision, Decision) + + def test_fallback_flees_hazard(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_text_only_response() + + adapter = _make_adapter(mock_client) + obs = _make_obs( + position=(0.0, 0.0, 0.0), + nearby_hazards=[ + 
HazardInfo( + name="fire", type="fire", position=(1.0, 0.0, 0.0), distance=1.0 + ) + ], + ) + decision = adapter.decide(obs) + + assert decision.tool == "move_to" + # Should move away from hazard (negative X) + assert decision.params["target_position"][0] < 0 + + +class TestTraceRecording: + def test_trace_populated_on_tool_use(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response() + + adapter = _make_adapter(mock_client) + adapter.decide(_make_obs()) + + assert adapter.last_trace is not None + assert adapter.last_trace["parse_method"] == "tool_use" + assert adapter.last_trace["decision"]["tool"] == "move_to" + assert adapter.last_trace["tokens_used"] == 165 # 120 + 45 + + def test_trace_populated_on_fallback(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_text_only_response() + + adapter = _make_adapter(mock_client) + adapter.decide(_make_obs()) + + assert adapter.last_trace is not None + assert adapter.last_trace["parse_method"] == "fallback_no_tool_use" + + def test_trace_populated_on_error(self): + mock_client = MagicMock() + mock_client.messages.create.side_effect = RuntimeError("boom") + + adapter = _make_adapter(mock_client) + adapter.decide(_make_obs()) + + assert adapter.last_trace is not None + assert adapter.last_trace["parse_method"] == "error" + + def test_trace_has_system_and_user_prompt(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response() + + adapter = _make_adapter(mock_client) + adapter.decide(_make_obs()) + + assert adapter.last_trace["system_prompt"] == adapter.system_prompt + assert "Tick: 1" in adapter.last_trace["user_prompt"] + + +class TestAPICallParameters: + def test_messages_create_called_with_tools(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response() + + adapter = _make_adapter(mock_client) + adapter.decide(_make_obs()) + + call_kwargs = 
mock_client.messages.create.call_args.kwargs + assert call_kwargs["model"] == "claude-sonnet-4-20250514" + assert call_kwargs["max_tokens"] == 1024 + assert isinstance(call_kwargs["tools"], list) + assert len(call_kwargs["tools"]) == 5 # move_to, collect, craft_item, explore, idle + + # Verify tool names + tool_names = {t["name"] for t in call_kwargs["tools"]} + assert tool_names == {"move_to", "collect", "craft_item", "explore", "idle"} + + def test_observation_is_user_message(self): + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response() + + adapter = _make_adapter(mock_client) + adapter.decide(_make_obs(tick=42)) + + call_kwargs = mock_client.messages.create.call_args.kwargs + messages = call_kwargs["messages"] + assert len(messages) == 1 + assert messages[0]["role"] == "user" + assert "Tick: 42" in messages[0]["content"] + + def test_custom_model(self): + adapter = ClaudeAdapter(model="claude-haiku-4-5-20251001", api_key="k") + + mock_client = MagicMock() + mock_client.messages.create.return_value = _mock_tool_use_response() + adapter.client = mock_client + + adapter.decide(_make_obs()) + + call_kwargs = mock_client.messages.create.call_args.kwargs + assert call_kwargs["model"] == "claude-haiku-4-5-20251001" diff --git a/tests/test_arena.py b/tests/test_arena.py deleted file mode 100644 index 31230b0..0000000 --- a/tests/test_arena.py +++ /dev/null @@ -1,311 +0,0 @@ -""" -Tests for AgentArena orchestrator and IPC integration. 
-""" - -import pytest - -from agent_runtime import AgentArena -from agent_runtime.behavior import AgentBehavior -from agent_runtime.schemas import AgentDecision, Observation - - -class MockBehavior(AgentBehavior): - """Mock agent behavior for testing.""" - - def __init__(self, decision=None): - self.decision = decision or AgentDecision.idle() - self.observations = [] - self.tool_schemas = [] - - def decide(self, observation, tools): - self.observations.append(observation) - self.tool_schemas = tools - return self.decision - - -class TestAgentArena: - """Tests for AgentArena class.""" - - def test_initialization(self): - """Test basic initialization.""" - arena = AgentArena(max_workers=8) - assert arena.runtime is not None - assert arena.behaviors == {} - assert arena.ipc_server is None - assert not arena.is_running() - - def test_default_workers(self): - """Test default worker count.""" - arena = AgentArena() - assert arena.runtime.executor._max_workers == 4 - - def test_custom_workers(self): - """Test custom worker count.""" - arena = AgentArena(max_workers=8) - assert arena.runtime.executor._max_workers == 8 - - def test_register_behavior(self): - """Test registering agent behaviors.""" - arena = AgentArena() - behavior = MockBehavior() - - arena.register('agent_001', behavior) - assert 'agent_001' in arena.behaviors - assert arena.behaviors['agent_001'] == behavior - - def test_register_multiple_behaviors(self): - """Test registering multiple agent behaviors.""" - arena = AgentArena() - behavior1 = MockBehavior() - behavior2 = MockBehavior() - - arena.register('agent_001', behavior1) - arena.register('agent_002', behavior2) - - assert len(arena.behaviors) == 2 - assert arena.behaviors['agent_001'] == behavior1 - assert arena.behaviors['agent_002'] == behavior2 - - def test_register_replaces_existing(self): - """Test that registering same ID replaces existing behavior.""" - arena = AgentArena() - behavior1 = MockBehavior() - behavior2 = MockBehavior() - - 
arena.register('agent_001', behavior1) - arena.register('agent_001', behavior2) - - assert len(arena.behaviors) == 1 - assert arena.behaviors['agent_001'] == behavior2 - - def test_unregister(self): - """Test unregistering agents.""" - arena = AgentArena() - behavior = MockBehavior() - - arena.register('agent_001', behavior) - assert 'agent_001' in arena.behaviors - - arena.unregister('agent_001') - assert 'agent_001' not in arena.behaviors - - def test_unregister_nonexistent(self): - """Test unregistering nonexistent agent doesn't error.""" - arena = AgentArena() - # Should not raise - arena.unregister('nonexistent') - - def test_get_registered_agents(self): - """Test getting list of registered agents.""" - arena = AgentArena() - behavior1 = MockBehavior() - behavior2 = MockBehavior() - - assert arena.get_registered_agents() == [] - - arena.register('agent_001', behavior1) - arena.register('agent_002', behavior2) - - agents = arena.get_registered_agents() - assert len(agents) == 2 - assert 'agent_001' in agents - assert 'agent_002' in agents - - def test_get_behavior(self): - """Test getting behavior for an agent.""" - arena = AgentArena() - behavior = MockBehavior() - - arena.register('agent_001', behavior) - - assert arena.get_behavior('agent_001') == behavior - assert arena.get_behavior('nonexistent') is None - - def test_is_running(self): - """Test running state tracking.""" - arena = AgentArena() - assert not arena.is_running() - - # We can't easily test running state without actually starting the server - # Just verify the property exists and returns False initially - - def test_run_without_connection_raises(self): - """Test that run() raises if not connected.""" - arena = AgentArena() - - with pytest.raises(RuntimeError, match="Not connected"): - arena.run() - - def test_run_async_without_connection_raises(self): - """Test that run_async() raises if not connected.""" - import asyncio - - arena = AgentArena() - - with pytest.raises(RuntimeError, match="Not 
connected"): - asyncio.run(arena.run_async()) - - def test_stop(self): - """Test stopping arena.""" - arena = AgentArena() - - # Should not raise even if not running - arena.stop() - - assert not arena.is_running() - - -# Integration tests for IPC converters - -from ipc.converters import decision_to_action, perception_to_observation -from ipc.messages import PerceptionMessage - - -class TestIPCConverters: - """Tests for IPC converter functions.""" - - def test_perception_to_observation_minimal(self): - """Test converting minimal perception to observation.""" - perception = PerceptionMessage( - agent_id="agent_001", - tick=10, - position=[1.0, 2.0, 3.0], - rotation=[0.0, 90.0, 0.0], - ) - - obs = perception_to_observation(perception) - - assert obs.agent_id == "agent_001" - assert obs.tick == 10 - assert obs.position == (1.0, 2.0, 3.0) - assert obs.rotation == (0.0, 90.0, 0.0) - assert obs.health == 100.0 - assert obs.energy == 100.0 - - def test_perception_to_observation_with_custom_data(self): - """Test converting perception with custom data.""" - perception = PerceptionMessage( - agent_id="agent_001", - tick=10, - position=[1.0, 2.0, 3.0], - rotation=[0.0, 0.0, 0.0], - custom_data={ - "nearby_resources": [ - { - "name": "apple", - "type": "food", - "position": [5.0, 0.0, 3.0], - "distance": 4.1, - } - ], - "nearby_hazards": [ - { - "name": "lava", - "type": "environmental", - "position": [10.0, 0.0, 10.0], - "distance": 12.7, - "damage": 50.0, - } - ], - "custom_field": "value", - }, - ) - - obs = perception_to_observation(perception) - - assert len(obs.nearby_resources) == 1 - assert obs.nearby_resources[0].name == "apple" - assert obs.nearby_resources[0].position == (5.0, 0.0, 3.0) - - assert len(obs.nearby_hazards) == 1 - assert obs.nearby_hazards[0].name == "lava" - assert obs.nearby_hazards[0].damage == 50.0 - - assert obs.custom["custom_field"] == "value" - assert "nearby_resources" not in obs.custom - assert "nearby_hazards" not in obs.custom - - def 
test_perception_to_observation_with_inventory(self): - """Test converting perception with inventory.""" - perception = PerceptionMessage( - agent_id="agent_001", - tick=10, - position=[0.0, 0.0, 0.0], - rotation=[0.0, 0.0, 0.0], - inventory=[ - {"id": "item_1", "name": "sword", "quantity": 1}, - {"id": "item_2", "name": "potion", "quantity": 5}, - ], - ) - - obs = perception_to_observation(perception) - - assert len(obs.inventory) == 2 - assert obs.inventory[0].name == "sword" - assert obs.inventory[0].quantity == 1 - assert obs.inventory[1].name == "potion" - assert obs.inventory[1].quantity == 5 - - def test_perception_to_observation_with_entities(self): - """Test converting perception with visible entities.""" - perception = PerceptionMessage( - agent_id="agent_001", - tick=10, - position=[0.0, 0.0, 0.0], - rotation=[0.0, 0.0, 0.0], - visible_entities=[ - { - "id": "tree_1", - "type": "obstacle", - "position": [3.0, 0.0, 4.0], - "distance": 5.0, - "metadata": {"height": 10}, - } - ], - ) - - obs = perception_to_observation(perception) - - assert len(obs.visible_entities) == 1 - assert obs.visible_entities[0].id == "tree_1" - assert obs.visible_entities[0].type == "obstacle" - assert obs.visible_entities[0].distance == 5.0 - assert obs.visible_entities[0].metadata["height"] == 10 - - def test_decision_to_action(self): - """Test converting decision to action message.""" - decision = AgentDecision( - tool="move_to", - params={"target_position": [10.0, 0.0, 5.0], "speed": 2.0}, - reasoning="Moving to resource", - ) - - action = decision_to_action(decision, "agent_001", 15) - - assert action.agent_id == "agent_001" - assert action.tick == 15 - assert action.tool == "move_to" - assert action.params == {"target_position": [10.0, 0.0, 5.0], "speed": 2.0} - assert action.reasoning == "Moving to resource" - - def test_decision_to_action_idle(self): - """Test converting idle decision to action.""" - decision = AgentDecision.idle() - - action = decision_to_action(decision, 
"agent_001", 5) - - assert action.tool == "idle" - assert action.params == {} - assert action.reasoning == "" - - def test_decision_to_action_no_reasoning(self): - """Test converting decision without reasoning.""" - decision = AgentDecision( - tool="pickup", - params={"item_id": "apple"}, - ) - - action = decision_to_action(decision, "agent_001", 10) - - assert action.tool == "pickup" - assert action.reasoning == "" diff --git a/tests/test_local_llm_behavior.py b/tests/test_local_llm_behavior.py deleted file mode 100644 index 8e8020b..0000000 --- a/tests/test_local_llm_behavior.py +++ /dev/null @@ -1,558 +0,0 @@ -""" -Tests for LocalLLMBehavior. -""" - -import json -from unittest.mock import Mock, patch - -import pytest - -from agent_runtime.local_llm_behavior import LocalLLMBehavior, create_local_llm_behavior -from agent_runtime.schemas import AgentDecision, HazardInfo, Observation, ResourceInfo, ToolSchema -from backends.base import BackendConfig, GenerationResult - - -class MockBackend: - """Mock backend for testing LocalLLMBehavior.""" - - def __init__(self, config: BackendConfig): - self.config = config - self.available = True - self.generate_calls = [] - self.generate_with_tools_calls = [] - self.mock_response = None # Can be set to override default response - - def is_available(self) -> bool: - return self.available - - def generate(self, prompt: str, temperature=None, max_tokens=None) -> GenerationResult: - self.generate_calls.append((prompt, temperature, max_tokens)) - if self.mock_response: - return self.mock_response - return GenerationResult( - text='{"tool": "idle", "params": {}, "reasoning": "Test decision"}', - tokens_used=50, - finish_reason="stop", - metadata={}, - ) - - def generate_with_tools( - self, prompt: str, tools: list[dict], temperature=None - ) -> GenerationResult: - self.generate_with_tools_calls.append((prompt, tools, temperature)) - if self.mock_response: - return self.mock_response - return GenerationResult( - text='{"tool": "idle", 
"params": {}, "reasoning": "Test decision"}', - tokens_used=50, - finish_reason="stop", - metadata={}, - ) - - def unload(self) -> None: - pass - - -class TestLocalLLMBehavior: - """Tests for LocalLLMBehavior class.""" - - def test_initialization_with_available_backend(self): - """Test that LocalLLMBehavior can be initialized with an available backend.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - - behavior = LocalLLMBehavior(backend=backend, system_prompt="Test prompt") - - assert behavior.backend is backend - assert behavior.system_prompt == "Test prompt" - assert behavior.memory.capacity == 10 - - def test_initialization_with_unavailable_backend(self): - """Test that LocalLLMBehavior raises error if backend is unavailable.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - backend.available = False - - with pytest.raises(RuntimeError, match="is not available"): - LocalLLMBehavior(backend=backend) - - def test_initialization_with_custom_memory_capacity(self): - """Test initialization with custom memory capacity.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - - behavior = LocalLLMBehavior(backend=backend, memory_capacity=20) - - assert behavior.memory.capacity == 20 - - def test_initialization_with_temperature_and_max_tokens(self): - """Test initialization with custom temperature and max tokens.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - - behavior = LocalLLMBehavior( - backend=backend, temperature=0.8, max_tokens=512 - ) - - assert behavior.temperature == 0.8 - assert behavior.max_tokens == 512 - - def test_decide_calls_backend_with_tools(self): - """Test that decide() calls backend.generate_with_tools().""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation( - 
agent_id="test_agent", - tick=1, - position=(0.0, 0.0, 0.0), - ) - - tools = [ - ToolSchema( - name="move_to", - description="Move to a position", - parameters={"type": "object"}, - ) - ] - - decision = behavior.decide(observation, tools) - - # Verify backend was called - assert len(backend.generate_with_tools_calls) == 1 - prompt, tools_arg, temp = backend.generate_with_tools_calls[0] - - # Verify prompt contains observation data - assert "test_agent" not in prompt # agent_id not in prompt - assert "Tick: 1" in prompt - assert "Position: (0.0, 0.0, 0.0)" in prompt - - # Verify tools were passed - assert len(tools_arg) == 1 - assert tools_arg[0]["name"] == "move_to" - - # Verify decision was returned - assert isinstance(decision, AgentDecision) - assert decision.tool == "idle" - - def test_decide_stores_observation_in_memory(self): - """Test that decide() stores observation in memory.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend, memory_capacity=5) - - # Make multiple decisions - for i in range(3): - obs = Observation( - agent_id="test_agent", - tick=i, - position=(float(i), 0.0, 0.0), - ) - behavior.decide(obs, []) - - # Verify all observations are in memory - memory_items = behavior.memory.retrieve() - assert len(memory_items) == 3 - # retrieve() returns most recent first - assert memory_items[0].tick == 2 - assert memory_items[2].tick == 0 - - def test_decide_includes_memory_in_prompt(self): - """Test that decide() includes recent observations in prompt.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - # Add some observations to memory - for i in range(3): - obs = Observation( - agent_id="test_agent", - tick=i, - position=(float(i), 0.0, 0.0), - ) - behavior.decide(obs, []) - - # Check that the last call included memory context - prompt, _, _ = backend.generate_with_tools_calls[-1] 
- assert "## Recent History" in prompt # Updated to match new format - - def test_decide_handles_resources_in_observation(self): - """Test that decide() includes resources in prompt.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation( - agent_id="test_agent", - tick=1, - position=(0.0, 0.0, 0.0), - nearby_resources=[ - ResourceInfo( - name="apple", - type="food", - position=(5.0, 0.0, 0.0), - distance=5.0, - ), - ResourceInfo( - name="wood", - type="material", - position=(10.0, 0.0, 0.0), - distance=10.0, - ), - ], - ) - - behavior.decide(observation, []) - - # Verify resources were included in prompt - prompt, _, _ = backend.generate_with_tools_calls[0] - assert "## Nearby Resources" in prompt - assert "apple (food)" in prompt - assert "wood (material)" in prompt - - def test_decide_handles_hazards_in_observation(self): - """Test that decide() includes hazards in prompt.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation( - agent_id="test_agent", - tick=1, - position=(0.0, 0.0, 0.0), - nearby_hazards=[ - HazardInfo( - name="fire", - type="hazard", - position=(3.0, 0.0, 0.0), - distance=3.0, - damage=10.0, - ), - ], - ) - - behavior.decide(observation, []) - - # Verify hazards were included in prompt - prompt, _, _ = backend.generate_with_tools_calls[0] - assert "## Nearby Hazards" in prompt - assert "fire (hazard)" in prompt - assert "damage: 10.0" in prompt - - def test_parse_decision_from_json_text(self): - """Test parsing decision from JSON text response.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - backend.mock_response = GenerationResult( - text='{"tool": "move_to", "params": {"target_position": [5, 0, 0]}, "reasoning": "Moving to resource"}', - tokens_used=50, - finish_reason="stop", 
- metadata={}, - ) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation(agent_id="test", tick=0, position=(0.0, 0.0, 0.0)) - decision = behavior.decide(observation, []) - - assert decision.tool == "move_to" - assert decision.params == {"target_position": [5, 0, 0]} - assert decision.reasoning == "Moving to resource" - - def test_parse_decision_from_json_with_markdown(self): - """Test parsing decision from JSON wrapped in markdown code blocks.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - backend.mock_response = GenerationResult( - text='```json\n{"tool": "pickup", "params": {"item_id": "apple"}}\n```', - tokens_used=50, - finish_reason="stop", - metadata={}, - ) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation(agent_id="test", tick=0, position=(0.0, 0.0, 0.0)) - decision = behavior.decide(observation, []) - - assert decision.tool == "pickup" - assert decision.params == {"item_id": "apple"} - - def test_parse_decision_from_native_tool_call(self): - """Test parsing decision from native tool call (e.g., vLLM).""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - backend.mock_response = GenerationResult( - text="Moving to the nearest resource", - tokens_used=50, - finish_reason="stop", - metadata={ - "tool_call": { - "name": "move_to", - "arguments": {"target_position": [10, 0, 5]}, - } - }, - ) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation(agent_id="test", tick=0, position=(0.0, 0.0, 0.0)) - decision = behavior.decide(observation, []) - - assert decision.tool == "move_to" - assert decision.params == {"target_position": [10, 0, 5]} - assert decision.reasoning == "Moving to the nearest resource" - - def test_parse_decision_from_parsed_tool_call(self): - """Test parsing decision from pre-parsed tool call (e.g., llama.cpp).""" - config = BackendConfig(model_path="test_model.gguf") - backend = 
MockBackend(config) - backend.mock_response = GenerationResult( - text='{"tool": "idle", "params": {}}', - tokens_used=50, - finish_reason="stop", - metadata={ - "parsed_tool_call": { - "tool": "move_to", - "params": {"target_position": [5, 0, 0]}, - "reasoning": "Pre-parsed decision", - } - }, - ) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation(agent_id="test", tick=0, position=(0.0, 0.0, 0.0)) - decision = behavior.decide(observation, []) - - assert decision.tool == "move_to" - assert decision.params == {"target_position": [5, 0, 0]} - assert decision.reasoning == "Pre-parsed decision" - - def test_parse_decision_fallback_to_idle(self): - """Test that invalid responses fall back to idle.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - backend.mock_response = GenerationResult( - text="This is not valid JSON", - tokens_used=50, - finish_reason="stop", - metadata={}, - ) - behavior = LocalLLMBehavior(backend=backend) - - observation = Observation(agent_id="test", tick=0, position=(0.0, 0.0, 0.0)) - decision = behavior.decide(observation, []) - - assert decision.tool == "idle" - assert "Parse error" in decision.reasoning - - def test_decide_handles_backend_error(self): - """Test that decide() handles backend errors gracefully.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - # Make backend raise an error - backend.generate_with_tools = Mock(side_effect=RuntimeError("Backend error")) - - observation = Observation( - agent_id="test_agent", - tick=1, - position=(0.0, 0.0, 0.0), - ) - - decision = behavior.decide(observation, []) - - # Should return idle decision with error message - assert decision.tool == "idle" - assert "Error" in decision.reasoning - - def test_on_episode_start_clears_memory(self): - """Test that on_episode_start() clears memory.""" - config = BackendConfig(model_path="test_model.gguf") - 
backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - # Add observations to memory - for i in range(3): - obs = Observation(agent_id="test", tick=i, position=(float(i), 0.0, 0.0)) - behavior.decide(obs, []) - - assert len(behavior.memory) == 3 - - # Start new episode - behavior.on_episode_start() - - # Memory should be cleared - assert len(behavior.memory) == 0 - - def test_on_episode_end_logs_metrics(self): - """Test that on_episode_end() is called without errors.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - # Should not raise - behavior.on_episode_end(success=True, metrics={"score": 100}) - behavior.on_episode_end(success=False) - - def test_on_tool_result_logs_result(self): - """Test that on_tool_result() is called without errors.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend) - - # Should not raise - behavior.on_tool_result("move_to", {"success": True}) - - -class TestCreateLocalLLMBehavior: - """Tests for create_local_llm_behavior() factory function.""" - - def test_create_with_model_path_only(self): - """Test creating behavior with just a model_path.""" - # Skip - requires llama-cpp-python to be installed - pytest.skip("Requires llama-cpp-python installation") - - behavior = create_local_llm_behavior(model_path="test_model.gguf") - - assert isinstance(behavior, LocalLLMBehavior) - # Should use default foraging prompt - assert "foraging agent" in behavior.system_prompt.lower() - assert behavior.memory.capacity == 10 - - def test_create_with_custom_system_prompt(self): - """Test creating behavior with custom system prompt.""" - pytest.skip("Requires llama-cpp-python installation") - - behavior = create_local_llm_behavior( - model_path="test_model.gguf", system_prompt="Custom prompt" - ) - - assert behavior.system_prompt == "Custom prompt" - - def 
test_create_with_custom_memory_capacity(self): - """Test creating behavior with custom memory capacity.""" - pytest.skip("Requires llama-cpp-python installation") - - behavior = create_local_llm_behavior(model_path="test_model.gguf", memory_capacity=20) - - assert behavior.memory.capacity == 20 - - def test_create_with_temperature_and_max_tokens(self): - """Test creating behavior with custom temperature and max tokens.""" - pytest.skip("Requires llama-cpp-python installation") - - behavior = create_local_llm_behavior( - model_path="test_model.gguf", temperature=0.9, max_tokens=1024 - ) - - assert behavior.temperature == 0.9 - assert behavior.max_tokens == 1024 - - def test_create_with_all_parameters(self): - """Test creating behavior with all parameters.""" - pytest.skip("Requires llama-cpp-python installation") - - behavior = create_local_llm_behavior( - model_path="test_model.gguf", - system_prompt="Test prompt", - memory_capacity=15, - temperature=0.5, - max_tokens=512, - ) - - assert behavior.system_prompt == "Test prompt" - assert behavior.memory.capacity == 15 - assert behavior.temperature == 0.5 - assert behavior.max_tokens == 512 - - -class TestLocalLLMBehaviorIntegration: - """Integration tests for LocalLLMBehavior.""" - - def test_full_decision_cycle(self): - """Test a full decision cycle from observation to decision.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - - # Configure backend to return a move_to decision - backend.generate_with_tools = Mock( - return_value=GenerationResult( - text='{"tool": "move_to", "params": {"target_position": [10, 0, 0]}, "reasoning": "Moving to resource"}', - tokens_used=75, - finish_reason="stop", - metadata={}, - ) - ) - - behavior = LocalLLMBehavior( - backend=backend, - system_prompt="You are a foraging agent.", - memory_capacity=5, - ) - - observation = Observation( - agent_id="forager_001", - tick=10, - position=(0.0, 0.0, 0.0), - nearby_resources=[ - ResourceInfo( - 
name="apple", - type="food", - position=(10.0, 0.0, 0.0), - distance=10.0, - ) - ], - ) - - tools = [ - ToolSchema( - name="move_to", - description="Move to target position", - parameters={"type": "object"}, - ), - ToolSchema( - name="idle", - description="Do nothing", - parameters={"type": "object"}, - ), - ] - - decision = behavior.decide(observation, tools) - - # Verify decision - assert decision.tool == "move_to" - assert decision.params == {"target_position": [10, 0, 0]} - assert decision.reasoning == "Moving to resource" - - # Verify observation was stored - memory_items = behavior.memory.retrieve() - assert len(memory_items) == 1 - assert memory_items[0].tick == 10 - - def test_multiple_decision_cycles_with_memory(self): - """Test multiple decisions with memory building up.""" - config = BackendConfig(model_path="test_model.gguf") - backend = MockBackend(config) - behavior = LocalLLMBehavior(backend=backend, memory_capacity=3) - - tools = [ - ToolSchema(name="move_to", description="Move", parameters={}), - ToolSchema(name="idle", description="Idle", parameters={}), - ] - - # Make 5 decisions - for i in range(5): - obs = Observation( - agent_id="test", - tick=i, - position=(float(i), 0.0, 0.0), - ) - decision = behavior.decide(obs, tools) - assert isinstance(decision, AgentDecision) - - # Memory should only keep last 3 - memory_items = behavior.memory.retrieve() - assert len(memory_items) == 3 - # retrieve() returns most recent first - assert memory_items[0].tick == 4 - assert memory_items[2].tick == 2 diff --git a/tests/test_memory.py b/tests/test_memory.py deleted file mode 100644 index 474a918..0000000 --- a/tests/test_memory.py +++ /dev/null @@ -1,535 +0,0 @@ -""" -Tests for agent memory implementations. 
-""" - -import pytest - -from agent_runtime.memory import ( - AgentMemory, - RAGMemory, - SlidingWindowMemory, - SummarizingMemory, -) -from agent_runtime.schemas import HazardInfo, ItemInfo, Observation, ResourceInfo - - -class TestAgentMemory: - """Tests for AgentMemory abstract base class.""" - - def test_cannot_instantiate_directly(self): - """Test that AgentMemory cannot be instantiated directly.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - AgentMemory() - - def test_requires_method_implementations(self): - """Test that subclasses must implement all abstract methods.""" - - class IncompleteMemory(AgentMemory): - pass - - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - IncompleteMemory() - - def test_concrete_implementation_works(self): - """Test that concrete implementation can be instantiated.""" - - class ConcreteMemory(AgentMemory): - def __init__(self): - self._obs = [] - - def store(self, observation): - self._obs.append(observation) - - def retrieve(self, query=None, limit=None): - return self._obs - - def summarize(self): - return f"{len(self._obs)} observations" - - def clear(self): - self._obs.clear() - - memory = ConcreteMemory() - assert isinstance(memory, AgentMemory) - - obs = Observation(agent_id="test", tick=0, position=(0.0, 0.0, 0.0)) - memory.store(obs) - assert len(memory) == 1 - assert memory.summarize() == "1 observations" - - memory.clear() - assert len(memory) == 0 - - -class TestSlidingWindowMemory: - """Tests for SlidingWindowMemory implementation.""" - - def test_initialization(self): - """Test basic initialization.""" - memory = SlidingWindowMemory(capacity=5) - assert memory.capacity == 5 - assert len(memory) == 0 - - def test_default_capacity(self): - """Test default capacity value.""" - memory = SlidingWindowMemory() - assert memory.capacity == 10 - - def test_invalid_capacity(self): - """Test that invalid capacity raises error.""" - with pytest.raises(ValueError, 
match="Capacity must be at least 1"): - SlidingWindowMemory(capacity=0) - - with pytest.raises(ValueError, match="Capacity must be at least 1"): - SlidingWindowMemory(capacity=-1) - - def test_store_single_observation(self): - """Test storing a single observation.""" - memory = SlidingWindowMemory(capacity=5) - obs = Observation(agent_id="agent_1", tick=1, position=(1.0, 0.0, 0.0)) - - memory.store(obs) - assert len(memory) == 1 - - def test_store_multiple_observations(self): - """Test storing multiple observations.""" - memory = SlidingWindowMemory(capacity=5) - - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - assert len(memory) == 3 - - def test_capacity_enforcement(self): - """Test that capacity limit is enforced.""" - memory = SlidingWindowMemory(capacity=3) - - # Store 5 observations - for i in range(5): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Should only keep last 3 - assert len(memory) == 3 - - # Verify oldest were discarded - observations = memory.retrieve() - assert observations[0].tick == 4 # Most recent first - assert observations[1].tick == 3 - assert observations[2].tick == 2 - - def test_retrieve_all(self): - """Test retrieving all observations.""" - memory = SlidingWindowMemory(capacity=5) - - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - observations = memory.retrieve() - assert len(observations) == 3 - # Most recent first - assert observations[0].tick == 2 - assert observations[1].tick == 1 - assert observations[2].tick == 0 - - def test_retrieve_with_limit(self): - """Test retrieving limited number of observations.""" - memory = SlidingWindowMemory(capacity=10) - - for i in range(5): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - observations = memory.retrieve(limit=2) - assert len(observations) == 2 - 
assert observations[0].tick == 4 # Most recent - assert observations[1].tick == 3 - - def test_retrieve_ignores_query(self): - """Test that retrieve ignores query parameter.""" - memory = SlidingWindowMemory(capacity=5) - - obs = Observation(agent_id="agent_1", tick=1, position=(1.0, 0.0, 0.0)) - memory.store(obs) - - # Query should be ignored - observations = memory.retrieve(query="some query") - assert len(observations) == 1 - - def test_summarize_empty(self): - """Test summarize with no observations.""" - memory = SlidingWindowMemory(capacity=5) - summary = memory.summarize() - assert "No observations" in summary - - def test_summarize_with_observations(self): - """Test summarize with observations.""" - memory = SlidingWindowMemory(capacity=5) - - obs1 = Observation( - agent_id="agent_1", - tick=1, - position=(1.0, 0.0, 0.0), - health=90.0, - energy=80.0, - ) - obs2 = Observation( - agent_id="agent_1", - tick=2, - position=(2.0, 0.0, 0.0), - nearby_resources=[ - ResourceInfo(name="apple", type="food", position=(3.0, 0.0, 0.0), distance=1.0) - ], - health=85.0, - energy=75.0, - ) - - memory.store(obs1) - memory.store(obs2) - - summary = memory.summarize() - assert "Tick 1" in summary - assert "Tick 2" in summary - assert "apple" in summary - assert "Health: 85" in summary - - def test_summarize_with_inventory(self): - """Test summarize includes inventory.""" - memory = SlidingWindowMemory(capacity=5) - - obs = Observation( - agent_id="agent_1", - tick=1, - position=(1.0, 0.0, 0.0), - inventory=[ - ItemInfo(id="item_1", name="sword", quantity=1), - ItemInfo(id="item_2", name="potion", quantity=3), - ], - ) - - memory.store(obs) - summary = memory.summarize() - assert "swordx1" in summary - assert "potionx3" in summary - - def test_summarize_with_hazards(self): - """Test summarize includes hazards.""" - memory = SlidingWindowMemory(capacity=5) - - obs = Observation( - agent_id="agent_1", - tick=1, - position=(1.0, 0.0, 0.0), - nearby_hazards=[ - HazardInfo( - 
name="lava", type="environmental", position=(5.0, 0.0, 0.0), distance=4.0, damage=50.0 - ) - ], - ) - - memory.store(obs) - summary = memory.summarize() - assert "lava" in summary - assert "damage: 50" in summary - - def test_clear(self): - """Test clearing all memories.""" - memory = SlidingWindowMemory(capacity=5) - - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - assert len(memory) == 3 - - memory.clear() - assert len(memory) == 0 - assert memory.summarize() == "No observations in memory." - - -class MockBackend: - """Mock LLM backend for testing.""" - - def __init__(self, response="Test summary"): - self.response = response - self.last_prompt = None - - def generate(self, prompt): - self.last_prompt = prompt - return self.response - - -class TestSummarizingMemory: - """Tests for SummarizingMemory implementation.""" - - def test_initialization(self): - """Test basic initialization.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) - - assert memory.backend == backend - assert memory.buffer_capacity == 10 - assert memory.compression_trigger == 5 - assert len(memory) == 0 - - def test_default_parameters(self): - """Test default parameter values.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend) - - assert memory.buffer_capacity == 20 - assert memory.compression_trigger == 15 - - def test_invalid_parameters(self): - """Test that invalid parameters raise errors.""" - backend = MockBackend() - - with pytest.raises(ValueError, match="Buffer capacity must be at least 1"): - SummarizingMemory(backend=backend, buffer_capacity=0) - - with pytest.raises(ValueError, match="Compression trigger must be at least 1"): - SummarizingMemory(backend=backend, compression_trigger=0) - - with pytest.raises(ValueError, match="Compression trigger must be <= buffer capacity"): - SummarizingMemory(backend=backend, 
buffer_capacity=10, compression_trigger=15) - - def test_store_single_observation(self): - """Test storing a single observation.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) - - obs = Observation(agent_id="agent_1", tick=1, position=(1.0, 0.0, 0.0)) - memory.store(obs) - - assert len(memory) == 1 - assert len(memory._buffer) == 1 - - def test_store_below_compression_trigger(self): - """Test storing observations below compression trigger.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) - - for i in range(4): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - assert len(memory) == 4 - assert len(memory._buffer) == 4 - assert memory._summary == "" # No compression yet - - def test_compression_trigger(self): - """Test that compression is triggered at threshold.""" - backend = MockBackend(response="Compressed summary") - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=3) - - # Store observations up to trigger - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Compression should have been triggered - assert memory._summary == "Compressed summary" - assert backend.last_prompt is not None - assert "Summarize" in backend.last_prompt - - def test_compression_keeps_recent_observations(self): - """Test that compression keeps some recent observations.""" - backend = MockBackend(response="Summary") - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) - - # Store 5 observations to trigger compression - for i in range(5): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Should keep (buffer_capacity - compression_trigger) = 5 observations - assert len(memory._buffer) == 5 - - def 
test_multiple_compressions(self): - """Test multiple compression cycles.""" - backend = MockBackend(response="Summary iteration") - memory = SummarizingMemory(backend=backend, buffer_capacity=8, compression_trigger=4) - - # First batch - for i in range(4): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - first_summary = memory._summary - assert "Summary iteration" in first_summary - - # Second batch - for i in range(4, 8): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Should have compressed again - assert memory._summary != "" - - def test_retrieve_returns_buffer_only(self): - """Test that retrieve returns only buffer observations.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) - - for i in range(7): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Retrieve should only return buffer contents - observations = memory.retrieve() - assert len(observations) <= len(memory._buffer) - - def test_retrieve_with_limit(self): - """Test retrieve with limit parameter.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) - - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - observations = memory.retrieve(limit=2) - assert len(observations) == 2 - assert observations[0].tick == 2 # Most recent first - - def test_summarize_with_no_data(self): - """Test summarize with no observations.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend) - - summary = memory.summarize() - assert "No observations" in summary - - def test_summarize_with_buffer_only(self): - """Test summarize with only buffer observations.""" - backend = MockBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=5) 
- - obs = Observation(agent_id="agent_1", tick=1, position=(1.0, 0.0, 0.0)) - memory.store(obs) - - summary = memory.summarize() - assert "Recent Observations" in summary - assert "Tick 1" in summary - - def test_summarize_with_compressed_and_buffer(self): - """Test summarize includes both compressed summary and buffer.""" - backend = MockBackend(response="Old events summary") - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=3) - - # Trigger compression - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Add new observation - obs = Observation(agent_id="agent_1", tick=10, position=(10.0, 0.0, 0.0)) - memory.store(obs) - - summary = memory.summarize() - assert "Compressed Memory Summary" in summary - assert "Old events summary" in summary - assert "Recent Observations" in summary - - def test_clear(self): - """Test clearing all memories.""" - backend = MockBackend(response="Summary") - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=3) - - # Store and compress - for i in range(5): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - assert len(memory) > 0 - assert memory._summary != "" - - memory.clear() - assert len(memory) == 0 - assert memory._summary == "" - assert len(memory._buffer) == 0 - - def test_fallback_compression(self): - """Test fallback compression when backend fails.""" - - class FailingBackend: - def generate(self, prompt): - raise Exception("Backend error") - - backend = FailingBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=3) - - # Store observations to trigger compression - for i in range(3): - obs = Observation( - agent_id="agent_1", - tick=i, - position=(i, 0.0, 0.0), - nearby_resources=[ - ResourceInfo(name="wood", type="material", position=(i + 1, 0.0, 0.0), distance=1.0) - ], - ) - memory.store(obs) - - # Should 
use fallback compression - assert memory._summary != "" - assert "Ticks" in memory._summary - - def test_fallback_compression_no_backend_generate(self): - """Test fallback when backend has no generate method.""" - - class NoGenerateBackend: - pass - - backend = NoGenerateBackend() - memory = SummarizingMemory(backend=backend, buffer_capacity=10, compression_trigger=3) - - for i in range(3): - obs = Observation(agent_id="agent_1", tick=i, position=(i, 0.0, 0.0)) - memory.store(obs) - - # Should use fallback compression - assert memory._summary != "" - - -class TestRAGMemory: - """Tests for RAGMemory implementation.""" - - def test_initialization(self): - """Test that RAGMemory initializes correctly.""" - memory = RAGMemory() - assert isinstance(memory, AgentMemory) - assert len(memory) == 0 - - def test_initialization_with_args(self): - """Test that RAGMemory accepts configuration args.""" - memory = RAGMemory( - embedding_model="all-MiniLM-L6-v2", - similarity_threshold=0.5, - default_k=3 - ) - assert memory.similarity_threshold == 0.5 - assert memory.default_k == 3 - - def test_basic_store_and_retrieve(self): - """Test basic store and retrieve functionality.""" - memory = RAGMemory() - - # Store an observation - obs = Observation( - agent_id="test_agent", - tick=1, - position=(0.0, 0.0, 0.0), - health=100.0, - energy=100.0 - ) - memory.store(obs) - - assert len(memory) == 1 - - # Retrieve recent observations - results = memory.retrieve(limit=5) - assert len(results) == 1 - assert results[0].agent_id == "test_agent" diff --git a/tests/test_reasoning_trace.py b/tests/test_reasoning_trace.py index 749571f..2ff1fea 100644 --- a/tests/test_reasoning_trace.py +++ b/tests/test_reasoning_trace.py @@ -1,88 +1,93 @@ """Tests for the reasoning trace system.""" -import importlib.util import json -import sys import tempfile -import time from pathlib import Path import pytest -# Add python directory to path for imports -python_dir = Path(__file__).parent.parent / "python" 
-sys.path.insert(0, str(python_dir)) - - -def load_module_directly(name: str, file_path: Path): - """Load a module directly from file, adding it to sys.modules.""" - spec = importlib.util.spec_from_file_location(name, file_path) - module = importlib.util.module_from_spec(spec) - sys.modules[name] = module - spec.loader.exec_module(module) - return module +from agent_runtime.reasoning_trace import ( + ReasoningTrace, + TraceStep, + TraceStepName, + TraceStore, + get_global_trace_store, + set_global_trace_store, +) -# Import the reasoning_trace module directly to avoid heavy dependencies in __init__.py -reasoning_trace_module = load_module_directly( - "agent_runtime.reasoning_trace", python_dir / "agent_runtime" / "reasoning_trace.py" -) +class TestTraceStepName: + """Tests for TraceStepName enum.""" -ReasoningTrace = reasoning_trace_module.ReasoningTrace -TraceStep = reasoning_trace_module.TraceStep -TraceStore = reasoning_trace_module.TraceStore + def test_standard_step_names(self): + """Test that standard step names exist.""" + assert TraceStepName.OBSERVATION == "observation" + assert TraceStepName.DECISION == "decision" + assert TraceStepName.PROMPT_BUILDING == "prompt" + assert TraceStepName.LLM_REQUEST == "llm_request" + assert TraceStepName.LLM_RESPONSE == "response" + assert TraceStepName.RETRIEVED == "retrieved" class TestTraceStep: """Tests for TraceStep dataclass.""" def test_create_trace_step(self): - """Test creating a trace step.""" - step = TraceStep(name="test", data={"key": "value"}) - assert step.name == "test" - assert step.data == {"key": "value"} - assert step.timestamp > 0 - assert step.elapsed_ms == 0.0 - - def test_trace_step_to_dict(self): + """Test creating a trace step with all required fields.""" + step = TraceStep( + timestamp="2026-01-01T00:00:00Z", + agent_id="agent1", + tick=42, + name="observation", + data={"position": [1, 2, 3]}, + ) + assert step.name == "observation" + assert step.agent_id == "agent1" + assert step.tick == 42 + 
assert step.data == {"position": [1, 2, 3]} + assert step.duration_ms is None + + def test_create_with_duration(self): + """Test creating a trace step with duration.""" + step = TraceStep( + timestamp="2026-01-01T00:00:00Z", + agent_id="agent1", + tick=42, + name="llm_request", + data={"prompt": "test"}, + duration_ms=150.5, + ) + assert step.duration_ms == 150.5 + + def test_to_dict(self): """Test converting trace step to dict.""" - step = TraceStep(name="test", data={"key": "value"}, timestamp=1000.0, elapsed_ms=5.0) + step = TraceStep( + timestamp="2026-01-01T00:00:00Z", + agent_id="agent1", + tick=42, + name="test", + data={"key": "value"}, + ) result = step.to_dict() assert result["name"] == "test" assert result["data"] == {"key": "value"} - assert result["timestamp"] == 1000.0 - assert result["elapsed_ms"] == 5.0 - - def test_trace_step_from_dict(self): - """Test creating trace step from dict.""" - data = {"name": "test", "data": {"key": "value"}, "timestamp": 1000.0, "elapsed_ms": 5.0} - step = TraceStep.from_dict(data) - assert step.name == "test" - assert step.data == {"key": "value"} - assert step.timestamp == 1000.0 - assert step.elapsed_ms == 5.0 - - def test_trace_step_serialize_complex_data(self): - """Test serializing complex data types.""" - - class MockObject: - def __init__(self): - self.field = "value" - - step = TraceStep(name="test", data=MockObject()) - result = step.to_dict() - assert result["data"] == {"field": "value"} - - def test_trace_step_serialize_to_dict_method(self): - """Test serializing objects with to_dict method.""" - - class MockDataClass: - def to_dict(self): - return {"custom": "data"} - - step = TraceStep(name="test", data=MockDataClass()) + assert result["timestamp"] == "2026-01-01T00:00:00Z" + assert result["agent_id"] == "agent1" + assert result["tick"] == 42 + assert "duration_ms" not in result # None values omitted + + def test_to_dict_with_duration(self): + """Test to_dict includes duration_ms when set.""" + step = 
TraceStep( + timestamp="2026-01-01T00:00:00Z", + agent_id="agent1", + tick=42, + name="test", + data={}, + duration_ms=5.0, + ) result = step.to_dict() - assert result["data"] == {"custom": "data"} + assert result["duration_ms"] == 5.0 class TestReasoningTrace: @@ -90,73 +95,93 @@ class TestReasoningTrace: def test_create_trace(self): """Test creating a reasoning trace.""" - trace = ReasoningTrace(agent_id="agent1", tick=42, episode_id="ep1") + trace = ReasoningTrace( + agent_id="agent1", + tick=42, + episode_id="ep1", + start_time="2026-01-01T00:00:00Z", + ) assert trace.agent_id == "agent1" assert trace.tick == 42 assert trace.episode_id == "ep1" + assert trace.start_time == "2026-01-01T00:00:00Z" assert trace.steps == [] - assert trace.trace_id is not None - assert trace.start_time > 0 def test_add_step(self): """Test adding steps to a trace.""" - trace = ReasoningTrace(agent_id="agent1", tick=42, episode_id="ep1") - step = trace.add_step("observation", {"position": [1, 2, 3]}) + trace = ReasoningTrace( + agent_id="agent1", + tick=42, + episode_id="ep1", + start_time="2026-01-01T00:00:00Z", + ) + trace.add_step("observation", {"position": [1, 2, 3]}) assert len(trace.steps) == 1 assert trace.steps[0].name == "observation" assert trace.steps[0].data == {"position": [1, 2, 3]} - assert step.elapsed_ms >= 0 + assert trace.steps[0].agent_id == "agent1" + assert trace.steps[0].tick == 42 + + def test_add_step_with_duration(self): + """Test adding step with explicit duration.""" + trace = ReasoningTrace( + agent_id="agent1", + tick=42, + episode_id="ep1", + start_time="2026-01-01T00:00:00Z", + ) + trace.add_step("llm_request", {"prompt": "test"}, duration_ms=200.0) - def test_add_multiple_steps(self): - """Test adding multiple steps tracks elapsed time.""" - trace = ReasoningTrace(agent_id="agent1", tick=42, episode_id="ep1") + assert trace.steps[0].duration_ms == 200.0 - trace.add_step("step1", "data1") - time.sleep(0.01) # Small delay - trace.add_step("step2", 
"data2") + def test_add_multiple_steps(self): + """Test adding multiple steps.""" + trace = ReasoningTrace( + agent_id="agent1", + tick=42, + episode_id="ep1", + start_time="2026-01-01T00:00:00Z", + ) + trace.add_step("observation", {"pos": [0, 0, 0]}) + trace.add_step("decision", {"tool": "move_to"}) assert len(trace.steps) == 2 - assert trace.steps[1].elapsed_ms > trace.steps[0].elapsed_ms + assert trace.steps[0].name == "observation" + assert trace.steps[1].name == "decision" def test_to_dict(self): """Test converting trace to dict.""" - trace = ReasoningTrace(agent_id="agent1", tick=42, episode_id="ep1") + trace = ReasoningTrace( + agent_id="agent1", + tick=42, + episode_id="ep1", + start_time="2026-01-01T00:00:00Z", + ) trace.add_step("test", {"key": "value"}) result = trace.to_dict() assert result["agent_id"] == "agent1" assert result["tick"] == 42 assert result["episode_id"] == "ep1" + assert result["start_time"] == "2026-01-01T00:00:00Z" assert len(result["steps"]) == 1 - def test_to_json_and_from_json(self): - """Test JSON serialization round-trip.""" - trace = ReasoningTrace(agent_id="agent1", tick=42, episode_id="ep1") - trace.add_step("observation", {"position": [1, 2, 3]}) - trace.add_step("decision", {"tool": "move_to"}) - - json_str = trace.to_json() - restored = ReasoningTrace.from_json(json_str) - - assert restored.agent_id == trace.agent_id - assert restored.tick == trace.tick - assert restored.episode_id == trace.episode_id - assert len(restored.steps) == 2 - assert restored.steps[0].name == "observation" - assert restored.steps[1].name == "decision" - - def test_format_tree(self): - """Test tree formatting.""" - trace = ReasoningTrace(agent_id="agent1", tick=42, episode_id="ep1") + def test_to_jsonl(self): + """Test JSONL serialization.""" + trace = ReasoningTrace( + agent_id="agent1", + tick=42, + episode_id="ep1", + start_time="2026-01-01T00:00:00Z", + ) trace.add_step("observation", {"position": [1, 2, 3]}) - trace.add_step("decision", 
{"tool": "move_to", "params": {"target": [4, 5, 6]}}) - tree = trace.format_tree() - assert "Decision Trace - Agent: agent1, Tick: 42" in tree - assert "observation" in tree - assert "decision" in tree - assert "move_to" in tree + jsonl = trace.to_jsonl() + data = json.loads(jsonl) + assert data["agent_id"] == "agent1" + assert data["tick"] == 42 + assert len(data["steps"]) == 1 class TestTraceStore: @@ -169,196 +194,290 @@ def temp_dir(self): yield Path(tmpdir) @pytest.fixture - def store(self, temp_dir): - """Create a trace store in temp directory.""" - # Reset singleton - TraceStore.reset_instance() - return TraceStore(temp_dir) - - def test_create_store(self, temp_dir): - """Test creating a trace store.""" - store = TraceStore(temp_dir) - assert store.traces_dir == temp_dir - assert temp_dir.exists() - - def test_set_episode(self, store): - """Test setting an episode.""" - episode_id = store.set_episode("agent1", "ep_test") - assert episode_id == "ep_test" - assert store.get_episode("agent1") == "ep_test" - - def test_set_episode_auto_generate(self, store): - """Test auto-generating episode ID.""" - episode_id = store.set_episode("agent1") - assert episode_id.startswith("ep_") - assert store.get_episode("agent1") == episode_id - - def test_start_trace(self, store): - """Test starting a trace.""" - store.set_episode("agent1", "ep1") - trace = store.start_trace("agent1", tick=42) + def store(self): + """Create a basic in-memory trace store.""" + return TraceStore(enabled=True, log_to_file=False) + @pytest.fixture + def file_store(self, temp_dir): + """Create a trace store with file logging.""" + return TraceStore(enabled=True, log_to_file=True, log_dir=temp_dir) + + def test_create_store_defaults(self): + """Test creating a trace store with defaults.""" + store = TraceStore() + assert store.enabled is True + assert store.max_entries == 1000 + assert store.log_to_file is False + + def test_create_store_disabled(self): + """Test creating a disabled trace store.""" 
+ store = TraceStore(enabled=False) + assert store.enabled is False + + def test_start_and_get_capture(self, store): + """Test starting and retrieving a capture.""" + trace = store.start_capture("agent1", tick=42) + + assert trace is not None assert trace.agent_id == "agent1" assert trace.tick == 42 - assert trace.episode_id == "ep1" - def test_add_step(self, store): - """Test adding a step.""" - store.set_episode("agent1", "ep1") - step = store.add_step("agent1", tick=42, name="test", data={"key": "value"}) + retrieved = store.get_capture("agent1", 42) + assert retrieved is trace + + def test_start_capture_disabled(self): + """Test start_capture returns None when disabled.""" + store = TraceStore(enabled=False) + trace = store.start_capture("agent1", tick=42) + assert trace is None - assert step is not None - assert step.name == "test" + def test_finish_capture(self, store): + """Test finishing a capture.""" + trace = store.start_capture("agent1", tick=42) + trace.add_step("observation", {"pos": [0, 0, 0]}) + # Should not raise + store.finish_capture("agent1", 42) - def test_end_trace_writes_file(self, store, temp_dir): - """Test ending a trace writes to file.""" - store.set_episode("agent1", "ep1") - store.add_step("agent1", tick=42, name="test", data={"key": "value"}) - trace = store.end_trace("agent1") + def test_finish_capture_with_file(self, file_store, temp_dir): + """Test finishing a capture writes to JSONL file.""" + trace = file_store.start_capture("agent1", tick=42) + trace.add_step("test", {"key": "value"}) + file_store.finish_capture("agent1", 42) - assert trace is not None - trace_file = temp_dir / "agent1" / "ep1.jsonl" - assert trace_file.exists() + # Check that the episode file was written + episode_file = temp_dir / f"{file_store.episode_id}.jsonl" + assert episode_file.exists() - with open(trace_file) as f: + # Close file handle before reading (Windows file locking) + file_store.end_episode() + + with open(episode_file) as f: lines = 
f.readlines() assert len(lines) == 1 data = json.loads(lines[0]) assert data["agent_id"] == "agent1" assert data["tick"] == 42 - def test_get_last_decision(self, store, temp_dir): - """Test getting the last decision.""" - store.set_episode("agent1", "ep1") + def test_get_captures_for_agent(self, store): + """Test getting all captures for an agent.""" + store.start_capture("agent1", tick=1) + store.start_capture("agent1", tick=2) + store.start_capture("agent2", tick=1) - # Add multiple traces - store.add_step("agent1", tick=1, name="step1", data="data1") - store.end_trace("agent1") + traces = store.get_captures_for_agent("agent1") + assert len(traces) == 2 + assert traces[0].tick == 1 + assert traces[1].tick == 2 - store.add_step("agent1", tick=2, name="step2", data="data2") - store.end_trace("agent1") + def test_get_all_captures(self, store): + """Test getting all captures.""" + store.start_capture("agent1", tick=1) + store.start_capture("agent2", tick=1) + store.start_capture("agent1", tick=2) + + traces = store.get_all_captures() + assert len(traces) == 3 + + def test_get_all_captures_with_tick_range(self, store): + """Test filtering captures by tick range.""" + store.start_capture("agent1", tick=1) + store.start_capture("agent1", tick=5) + store.start_capture("agent1", tick=10) + + traces = store.get_all_captures(tick_start=3, tick_end=7) + assert len(traces) == 1 + assert traces[0].tick == 5 + + def test_max_entries_limit(self): + """Test that max_entries limit is enforced.""" + store = TraceStore(max_entries=3) + store.start_capture("agent1", tick=1) + store.start_capture("agent1", tick=2) + store.start_capture("agent1", tick=3) + store.start_capture("agent1", tick=4) + + assert len(store.traces) == 3 + # Oldest (tick=1) should have been evicted + assert store.get_capture("agent1", 1) is None + assert store.get_capture("agent1", 4) is not None + + def test_start_and_end_episode(self, file_store, temp_dir): + """Test episode lifecycle.""" + 
file_store.start_episode("ep_custom") + assert file_store.episode_id == "ep_custom" + + trace = file_store.start_capture("agent1", tick=1) + trace.add_step("test", {}) + file_store.finish_capture("agent1", 1) + + file_store.end_episode() + + # Verify file was written + episode_file = temp_dir / "ep_custom.jsonl" + assert episode_file.exists() + + def test_get_episode_traces(self, file_store, temp_dir): + """Test reading traces back from episode file.""" + file_store.start_episode("ep_read_test") + + trace = file_store.start_capture("agent1", tick=1) + trace.add_step("obs", {"pos": [0, 0, 0]}) + file_store.finish_capture("agent1", 1) + + trace2 = file_store.start_capture("agent1", tick=2) + trace2.add_step("dec", {"tool": "idle"}) + file_store.finish_capture("agent1", 2) + + file_store.end_episode() + + # Read back from file + traces = file_store.get_episode_traces("ep_read_test") + assert len(traces) == 2 + assert traces[0].tick == 1 + assert traces[1].tick == 2 - last = store.get_last_decision("agent1") - assert last is not None - assert last.tick == 2 + def test_watch_callback(self, store): + """Test watcher callbacks on finish_capture.""" + received = [] - def test_get_episode_traces(self, store): - """Test getting all traces for an episode.""" - store.set_episode("agent1", "ep1") + def on_trace(trace): + received.append(trace) - store.add_step("agent1", tick=1, name="s1", data="d1") - store.end_trace("agent1") + store.watch("agent1", on_trace) - store.add_step("agent1", tick=2, name="s2", data="d2") - store.end_trace("agent1") + trace = store.start_capture("agent1", tick=1) + trace.add_step("test", {}) + store.finish_capture("agent1", 1) - traces = store.get_episode_traces("agent1", "ep1") - assert len(traces) == 2 - assert traces[0].tick == 1 - assert traces[1].tick == 2 + assert len(received) == 1 + assert received[0].agent_id == "agent1" - def test_list_agents(self, store): - """Test listing agents.""" - store.set_episode("agent1", "ep1") - 
store.add_step("agent1", tick=1, name="s1", data="d1") - store.end_trace("agent1") + def test_watch_wildcard(self, store): + """Test wildcard watcher receives all agents.""" + received = [] + store.watch("*", lambda t: received.append(t)) - store.set_episode("agent2", "ep1") - store.add_step("agent2", tick=1, name="s1", data="d1") - store.end_trace("agent2") + store.start_capture("agent1", tick=1) + store.finish_capture("agent1", 1) - agents = store.list_agents() - assert set(agents) == {"agent1", "agent2"} + store.start_capture("agent2", tick=1) + store.finish_capture("agent2", 1) - def test_list_episodes(self, store): - """Test listing episodes.""" - store.set_episode("agent1", "ep1") - store.add_step("agent1", tick=1, name="s1", data="d1") - store.end_trace("agent1") + assert len(received) == 2 - store.set_episode("agent1", "ep2") - store.add_step("agent1", tick=1, name="s1", data="d1") - store.end_trace("agent1") + def test_unwatch(self, store): + """Test removing a watcher.""" + received = [] - episodes = store.list_episodes("agent1") - assert "ep1" in episodes - assert "ep2" in episodes + def callback(trace): + received.append(trace) - def test_no_traces_returns_none(self, store): - """Test getting traces when none exist.""" - assert store.get_last_decision("nonexistent") is None - assert store.get_episode_traces("agent1", "nonexistent") == [] + store.watch("agent1", callback) + store.unwatch("agent1", callback) - def test_singleton_pattern(self, temp_dir): - """Test singleton pattern.""" - TraceStore.reset_instance() - store1 = TraceStore.get_instance(temp_dir) - store2 = TraceStore.get_instance() - assert store1 is store2 + store.start_capture("agent1", tick=1) + store.finish_capture("agent1", 1) - TraceStore.reset_instance() + assert len(received) == 0 + def test_clear(self, store): + """Test clearing all in-memory traces.""" + store.start_capture("agent1", tick=1) + store.start_capture("agent1", tick=2) + assert len(store.traces) == 2 -class 
TestAgentBehaviorTracing: - """Tests for tracing integration with AgentBehavior.""" + store.clear() + assert len(store.traces) == 0 - @pytest.fixture - def temp_dir(self): - """Create a temporary directory for traces.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield Path(tmpdir) + def test_to_json(self, store): + """Test JSON export.""" + trace = store.start_capture("agent1", tick=1) + trace.add_step("obs", {"pos": [0, 0, 0]}) + + json_str = store.to_json() + data = json.loads(json_str) + assert len(data) == 1 + assert data[0]["agent_id"] == "agent1" + + def test_to_json_filter_by_agent(self, store): + """Test JSON export filtered by agent.""" + store.start_capture("agent1", tick=1) + store.start_capture("agent2", tick=1) + + json_str = store.to_json(agent_id="agent1") + data = json.loads(json_str) + assert len(data) == 1 + assert data[0]["agent_id"] == "agent1" + + def test_get_capture_nonexistent(self, store): + """Test getting a nonexistent capture returns None.""" + assert store.get_capture("agent1", 999) is None - def test_enable_tracing(self, temp_dir): - """Test enabling tracing on an agent behavior.""" - # Import directly to avoid heavy dependencies - schemas_module = load_module_directly( - "agent_runtime.schemas", python_dir / "agent_runtime" / "schemas.py" - ) - behavior_module = load_module_directly( - "agent_runtime.behavior", python_dir / "agent_runtime" / "behavior.py" - ) - AgentBehavior = behavior_module.AgentBehavior # noqa: N806 - AgentDecision = schemas_module.AgentDecision # noqa: N806 - Observation = schemas_module.Observation # noqa: N806 +class TestGlobalTraceStore: + """Tests for global singleton functions.""" + + def test_get_global_creates_default(self): + """Test get_global_trace_store creates a default store.""" + # Reset global + set_global_trace_store(None) + + # Import and reset the global + import agent_runtime.reasoning_trace as rt + + rt._global_trace_store = None + + store = get_global_trace_store() + assert store is 
not None + assert isinstance(store, TraceStore) + + def test_set_and_get_global(self): + """Test setting and getting a custom global store.""" + custom_store = TraceStore(enabled=False, max_entries=50) + set_global_trace_store(custom_store) + + retrieved = get_global_trace_store() + assert retrieved is custom_store + assert retrieved.enabled is False + assert retrieved.max_entries == 50 + + +class TestAgentBehaviorTracing: + """Tests for tracing integration with AgentBehavior.""" + + def test_log_step_with_tracing(self): + """Test log_step adds to current trace when tracing is active.""" + from agent_runtime.behavior import AgentBehavior + from agent_runtime.schemas import AgentDecision, Observation class TestAgent(AgentBehavior): def decide(self, observation, tools): self.log_step("test_step", {"data": "value"}) return AgentDecision.idle() - # Reset singleton to use our temp dir - TraceStore.reset_instance() - + store = TraceStore(enabled=True, log_to_file=False) agent = TestAgent() - agent.enable_tracing(temp_dir) - agent._agent_id = "test_agent" - agent._current_tick = 1 - - # Create a mock observation - obs = Observation( - agent_id="test_agent", - tick=1, - position=(0, 0, 0), - ) + agent._trace_store = store - # Make a decision + # Simulate what IPC server does + agent._set_trace_context("test_agent", 1) + + obs = Observation(agent_id="test_agent", tick=1, position=(0, 0, 0)) agent.decide(obs, []) agent._end_trace() - # Check trace was written - trace_file = list(temp_dir.glob("test_agent/*.jsonl")) - assert len(trace_file) == 1 - - # Cleanup - TraceStore.reset_instance() + # Verify the trace was captured + trace = store.get_capture("test_agent", 1) + assert trace is not None + assert len(trace.steps) == 1 + assert trace.steps[0].name == "test_step" + assert trace.steps[0].data == {"data": "value"} - def test_log_step_without_tracing(self, temp_dir): - """Test log_step is no-op when tracing is disabled.""" - # Use already loaded modules - AgentBehavior = 
sys.modules["agent_runtime.behavior"].AgentBehavior # noqa: N806 - AgentDecision = sys.modules["agent_runtime.schemas"].AgentDecision # noqa: N806 - Observation = sys.modules["agent_runtime.schemas"].Observation # noqa: N806 + def test_log_step_without_tracing(self): + """Test log_step is a no-op when tracing is disabled.""" + from agent_runtime.behavior import AgentBehavior + from agent_runtime.schemas import AgentDecision, Observation class TestAgent(AgentBehavior): def decide(self, observation, tools): @@ -366,17 +485,8 @@ def decide(self, observation, tools): return AgentDecision.idle() agent = TestAgent() - # Don't enable tracing - - obs = Observation( - agent_id="test_agent", - tick=1, - position=(0, 0, 0), - ) + # Don't set _trace_store + obs = Observation(agent_id="test_agent", tick=1, position=(0, 0, 0)) # Should not raise agent.decide(obs, []) - - # No trace files should exist in temp_dir - trace_files = list(temp_dir.glob("**/*.jsonl")) - assert len(trace_files) == 0 diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 8593786..d452220 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -716,11 +716,13 @@ def test_from_llm_response_invalid_json(self): with pytest.raises(ValueError, match="Invalid JSON"): AgentDecision.from_llm_response(response) - def test_from_llm_response_malformed_json(self): - """Test handling of malformed JSON.""" + def test_from_llm_response_truncated_json_recovery(self): + """Test recovery of truncated JSON (e.g., from LLM token limit).""" response = '{"tool": "move", "params": {' - with pytest.raises(ValueError, match="Invalid JSON"): - AgentDecision.from_llm_response(response) + # Truncated JSON is recovered by _recover_truncated_json + decision = AgentDecision.from_llm_response(response) + assert decision.tool == "move" + assert decision.params == {} def test_from_llm_response_alternate_reasoning_fields(self): """Test parsing with alternate reasoning field names.""" diff --git 
a/tests/test_vllm_backend.py b/tests/test_vllm_backend.py deleted file mode 100644 index b68f198..0000000 --- a/tests/test_vllm_backend.py +++ /dev/null @@ -1,196 +0,0 @@ -""" -Tests for vLLM backend. - -Note: These tests require a running vLLM server. -Use pytest markers to skip if server is not available. -""" - -import pytest -from backends.vllm_backend import VLLMBackend, VLLMBackendConfig - - -@pytest.fixture -def vllm_config(): - """Create a vLLM config for testing.""" - return VLLMBackendConfig( - model_path="meta-llama/Llama-2-7b-chat-hf", - api_base="http://localhost:8000/v1", - temperature=0.7, - max_tokens=100, - ) - - -@pytest.fixture -def vllm_backend(vllm_config): - """Create a vLLM backend instance.""" - try: - backend = VLLMBackend(vllm_config) - if not backend.is_available(): - pytest.skip("vLLM server not available") - return backend - except Exception as e: - pytest.skip(f"Could not connect to vLLM server: {e}") - - -def test_vllm_config_creation(): - """Test vLLM config initialization.""" - config = VLLMBackendConfig( - model_path="test-model", - api_base="http://test:8000/v1", - api_key="test-key", - temperature=0.5, - max_tokens=256, - ) - - assert config.model_path == "test-model" - assert config.api_base == "http://test:8000/v1" - assert config.api_key == "test-key" - assert config.temperature == 0.5 - assert config.max_tokens == 256 - - -def test_vllm_backend_initialization(vllm_config): - """Test vLLM backend can be initialized.""" - try: - backend = VLLMBackend(vllm_config) - assert backend.client is not None - assert backend.config == vllm_config - except Exception: - pytest.skip("vLLM server not available") - - -def test_vllm_is_available(vllm_backend): - """Test availability check.""" - assert vllm_backend.is_available() is True - - -def test_vllm_generate(vllm_backend): - """Test basic text generation.""" - prompt = "Hello, my name is" - result = vllm_backend.generate(prompt, max_tokens=20) - - assert result is not None - assert 
len(result.text) > 0 - assert result.tokens_used > 0 - assert result.finish_reason in ["stop", "length"] - assert "model" in result.metadata - - -def test_vllm_generate_with_temperature(vllm_backend): - """Test generation with custom temperature.""" - prompt = "The weather today is" - result = vllm_backend.generate(prompt, temperature=0.1, max_tokens=20) - - assert result is not None - assert len(result.text) > 0 - assert result.finish_reason in ["stop", "length"] - - -def test_vllm_generate_with_tools(vllm_backend): - """Test tool calling generation.""" - prompt = "I need to move to coordinates (10, 20, 5)" - - tools = [ - { - "name": "move_to", - "description": "Move agent to target coordinates", - "parameters": { - "type": "object", - "properties": { - "target": { - "type": "array", - "items": {"type": "number"}, - "description": "Target [x, y, z] coordinates", - } - }, - "required": ["target"], - }, - } - ] - - result = vllm_backend.generate_with_tools(prompt, tools) - - assert result is not None - # Result should contain either a tool call or text response - assert len(result.text) > 0 or "tool_call" in result.metadata - - -def test_vllm_generate_error_handling(vllm_backend): - """Test error handling with invalid input.""" - # Empty prompt should still work - result = vllm_backend.generate("", max_tokens=10) - assert result is not None - assert result.finish_reason in ["stop", "length", "error"] - - -def test_vllm_unload(vllm_config): - """Test unloading backend.""" - try: - backend = VLLMBackend(vllm_config) - backend.unload() - assert backend.client is None - assert backend.is_available() is False - except Exception: - pytest.skip("vLLM server not available") - - -def test_vllm_multiple_generations(vllm_backend): - """Test multiple sequential generations.""" - prompts = ["Hello", "How are you?", "What is AI?"] - - for prompt in prompts: - result = vllm_backend.generate(prompt, max_tokens=20) - assert result is not None - assert len(result.text) > 0 - - -def 
test_vllm_generate_with_tools_fallback(vllm_backend): - """Test fallback tool calling method.""" - prompt = "Pick up the sword item" - - tools = [ - { - "name": "pickup_item", - "description": "Pick up an item from the world", - "parameters": { - "type": "object", - "properties": { - "item_name": { - "type": "string", - "description": "Name of the item to pick up", - } - }, - "required": ["item_name"], - }, - } - ] - - # Test the fallback method directly - result = vllm_backend._generate_with_tools_fallback(prompt, tools, temperature=0.7) - - assert result is not None - assert len(result.text) > 0 - - -@pytest.mark.parametrize("max_tokens", [10, 50, 100]) -def test_vllm_different_max_tokens(vllm_backend, max_tokens): - """Test generation with different max token limits.""" - prompt = "Once upon a time" - result = vllm_backend.generate(prompt, max_tokens=max_tokens) - - assert result is not None - assert result.tokens_used <= max_tokens * 1.5 # Some tolerance - - -@pytest.mark.parametrize("temperature", [0.1, 0.7, 1.0]) -def test_vllm_different_temperatures(vllm_backend, temperature): - """Test generation with different temperatures.""" - prompt = "The capital of France is" - result = vllm_backend.generate(prompt, temperature=temperature, max_tokens=20) - - assert result is not None - assert len(result.text) > 0 - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"])