From 13e2a08d87b3fabfc6b437f6c1f555c0f654a536 Mon Sep 17 00:00:00 2001 From: minorun365 Date: Sun, 18 Jan 2026 19:27:48 +0900 Subject: [PATCH] Add A2A (Agent-to-Agent) protocol support - Add BedrockAgentCoreA2AApp class for hosting A2A agents - Add A2A models (AgentCard, AgentSkill, JsonRpcRequest/Response, etc.) - Add @entrypoint decorator for message handling - Support JSON-RPC 2.0 protocol on port 9000 - Export A2A classes from runtime module - Add comprehensive unit tests (54 tests) --- src/bedrock_agentcore/runtime/__init__.py | 31 +- src/bedrock_agentcore/runtime/a2a_app.py | 469 ++++++++++++++++ src/bedrock_agentcore/runtime/a2a_models.py | 284 ++++++++++ .../bedrock_agentcore/runtime/test_a2a_app.py | 507 ++++++++++++++++++ .../runtime/test_a2a_models.py | 340 ++++++++++++ 5 files changed, 1630 insertions(+), 1 deletion(-) create mode 100644 src/bedrock_agentcore/runtime/a2a_app.py create mode 100644 src/bedrock_agentcore/runtime/a2a_models.py create mode 100644 tests/bedrock_agentcore/runtime/test_a2a_app.py create mode 100644 tests/bedrock_agentcore/runtime/test_a2a_models.py diff --git a/src/bedrock_agentcore/runtime/__init__.py b/src/bedrock_agentcore/runtime/__init__.py index b86c8aa..3a979e2 100644 --- a/src/bedrock_agentcore/runtime/__init__.py +++ b/src/bedrock_agentcore/runtime/__init__.py @@ -1,19 +1,48 @@ """BedrockAgentCore Runtime Package. This package contains the core runtime components for Bedrock AgentCore applications: -- BedrockAgentCoreApp: Main application class +- BedrockAgentCoreApp: Main application class for HTTP protocol +- BedrockAgentCoreA2AApp: Application class for A2A (Agent-to-Agent) protocol - RequestContext: HTTP request context - BedrockAgentCoreContext: Agent identity context +- AgentCard, AgentSkill: A2A protocol metadata models """ +from .a2a_app import BedrockAgentCoreA2AApp +from .a2a_models import ( + A2A_DEFAULT_PORT, + A2AArtifact, + A2AMessage, + A2AMessagePart, + AgentCard, + AgentSkill, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + build_runtime_url, +) from .agent_core_runtime_client import AgentCoreRuntimeClient from .app import BedrockAgentCoreApp from .context import BedrockAgentCoreContext, RequestContext from .models import PingStatus __all__ = [ + # HTTP Protocol "AgentCoreRuntimeClient", "BedrockAgentCoreApp", + # A2A Protocol + "BedrockAgentCoreA2AApp", + "AgentCard", + "AgentSkill", + "A2AMessage", + "A2AMessagePart", + "A2AArtifact", + "JsonRpcRequest", + "JsonRpcResponse", + "JsonRpcErrorCode", + "A2A_DEFAULT_PORT", + "build_runtime_url", + # Common "RequestContext", "BedrockAgentCoreContext", "PingStatus", diff --git a/src/bedrock_agentcore/runtime/a2a_app.py b/src/bedrock_agentcore/runtime/a2a_app.py new file mode 100644 index 0000000..26670d1 --- /dev/null +++ b/src/bedrock_agentcore/runtime/a2a_app.py @@ -0,0 +1,469 @@ +"""Bedrock AgentCore A2A application implementation. + +Provides a Starlette-based web server for A2A (Agent-to-Agent) protocol communication. +""" + +import asyncio +import contextvars +import inspect +import json +import logging +import os +import threading +import time +import uuid +from collections.abc import Sequence +from typing import Any, Callable, Dict, Optional + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.responses import JSONResponse, StreamingResponse +from starlette.routing import Route +from starlette.types import Lifespan + +from .a2a_models import ( + A2A_DEFAULT_PORT, + AgentCard, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, +) +from .context import BedrockAgentCoreContext, RequestContext +from .models import ( + ACCESS_TOKEN_HEADER, + AUTHORIZATION_HEADER, + CUSTOM_HEADER_PREFIX, + OAUTH2_CALLBACK_URL_HEADER, + REQUEST_ID_HEADER, + SESSION_HEADER, + PingStatus, +) + + +class A2ARequestContextFormatter(logging.Formatter): + """Formatter including request and session IDs for A2A applications.""" + + def format(self, record): + """Format log record as AWS Lambda JSON.""" + from datetime import datetime, timezone + + log_entry = { + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z", + "level": record.levelname, + "message": record.getMessage(), + "logger": record.name, + "protocol": "A2A", + } + + request_id = BedrockAgentCoreContext.get_request_id() + if request_id: + log_entry["requestId"] = request_id + + session_id = BedrockAgentCoreContext.get_session_id() + if session_id: + log_entry["sessionId"] = session_id + + if record.exc_info: + import traceback + + log_entry["errorType"] = record.exc_info[0].__name__ + log_entry["errorMessage"] = str(record.exc_info[1]) + log_entry["stackTrace"] = traceback.format_exception(*record.exc_info) + log_entry["location"] = f"{record.pathname}:{record.funcName}:{record.lineno}" + + return json.dumps(log_entry, ensure_ascii=False) + + +class BedrockAgentCoreA2AApp(Starlette): + """Bedrock AgentCore A2A application class for agent-to-agent communication. + + This class implements the A2A protocol contract for AgentCore Runtime, + supporting JSON-RPC 2.0 messaging and agent discovery via Agent Cards. + + Example: + ```python + from bedrock_agentcore.runtime import BedrockAgentCoreA2AApp, AgentCard, AgentSkill + + agent_card = AgentCard( + name="Calculator Agent", + description="A calculator agent", + skills=[AgentSkill(id="calc", name="Calculator", description="Math ops")] + ) + + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.entrypoint + def handle_message(request, context): + # Process JSON-RPC request + message = request.params["message"] + user_text = message["parts"][0]["text"] + + # Return result (will be wrapped in JSON-RPC response) + return { + "artifacts": [{ + "artifactId": str(uuid.uuid4()), + "name": "response", + "parts": [{"kind": "text", "text": f"Result: {user_text}"}] + }] + } + + app.run() # Runs on port 9000 + ``` + """ + + def __init__( + self, + agent_card: AgentCard, + debug: bool = False, + lifespan: Optional[Lifespan] = None, + middleware: Sequence[Middleware] | None = None, + ): + """Initialize Bedrock AgentCore A2A application. + + Args: + agent_card: AgentCard containing agent metadata for discovery + debug: Enable debug mode for verbose logging (default: False) + lifespan: Optional lifespan context manager for startup/shutdown + middleware: Optional sequence of Starlette Middleware objects + """ + self.agent_card = agent_card + self.handlers: Dict[str, Callable] = {} + self._ping_handler: Optional[Callable] = None + self._active_tasks: Dict[int, Dict[str, Any]] = {} + self._task_counter_lock: threading.Lock = threading.Lock() + self._forced_ping_status: Optional[PingStatus] = None + self._last_status_update_time: float = time.time() + + routes = [ + Route("/", self._handle_jsonrpc, methods=["POST"]), + Route("/.well-known/agent-card.json", self._handle_agent_card, methods=["GET"]), + Route("/ping", self._handle_ping, methods=["GET"]), + ] + super().__init__(routes=routes, lifespan=lifespan, middleware=middleware) + self.debug = debug + + self.logger = logging.getLogger("bedrock_agentcore.a2a_app") + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = A2ARequestContextFormatter() + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(logging.DEBUG if self.debug else logging.INFO) + + def entrypoint(self, func: Callable) -> Callable: + """Decorator to register a function as the main message handler. + + The handler receives the JSON-RPC request and context, and should return + a result that will be wrapped in a JSON-RPC response. + + Args: + func: The function to register as entrypoint. + Signature: (request: JsonRpcRequest, context: RequestContext) -> Any + Or for streaming: async generator yielding response chunks + + Returns: + The decorated function with added run method + """ + self.handlers["main"] = func + func.run = lambda port=A2A_DEFAULT_PORT, host=None: self.run(port, host) + return func + + def ping(self, func: Callable) -> Callable: + """Decorator to register a custom ping status handler. + + Args: + func: The function to register as ping status handler + + Returns: + The decorated function + """ + self._ping_handler = func + return func + + def get_current_ping_status(self) -> PingStatus: + """Get current ping status (forced > custom > automatic).""" + current_status = None + + if self._forced_ping_status is not None: + current_status = self._forced_ping_status + elif self._ping_handler: + try: + result = self._ping_handler() + if isinstance(result, str): + current_status = PingStatus(result) + else: + current_status = result + except Exception as e: + self.logger.warning( + "Custom ping handler failed, falling back to automatic: %s: %s", type(e).__name__, e + ) + + if current_status is None: + current_status = PingStatus.HEALTHY_BUSY if self._active_tasks else PingStatus.HEALTHY + + if not hasattr(self, "_last_known_status") or self._last_known_status != current_status: + self._last_known_status = current_status + self._last_status_update_time = time.time() + + return current_status + + def _get_runtime_url(self) -> Optional[str]: + """Get the runtime URL from environment variable. + + Returns: + The runtime URL if set, None otherwise. + """ + return os.environ.get("AGENTCORE_RUNTIME_URL") + + def _build_request_context(self, request) -> RequestContext: + """Build request context and setup all context variables.""" + try: + headers = request.headers + request_id = headers.get(REQUEST_ID_HEADER) + if not request_id: + request_id = str(uuid.uuid4()) + + session_id = headers.get(SESSION_HEADER) + BedrockAgentCoreContext.set_request_context(request_id, session_id) + + agent_identity_token = headers.get(ACCESS_TOKEN_HEADER) + if agent_identity_token: + BedrockAgentCoreContext.set_workload_access_token(agent_identity_token) + + oauth2_callback_url = headers.get(OAUTH2_CALLBACK_URL_HEADER) + if oauth2_callback_url: + BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) + + # Collect relevant request headers + request_headers = {} + + authorization_header = headers.get(AUTHORIZATION_HEADER) + if authorization_header is not None: + request_headers[AUTHORIZATION_HEADER] = authorization_header + + for header_name, header_value in headers.items(): + if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()): + request_headers[header_name] = header_value + + if request_headers: + BedrockAgentCoreContext.set_request_headers(request_headers) + + req_headers = BedrockAgentCoreContext.get_request_headers() + + return RequestContext( + session_id=session_id, + request_headers=req_headers, + request=request, + ) + except Exception as e: + self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) + request_id = str(uuid.uuid4()) + BedrockAgentCoreContext.set_request_context(request_id, None) + return RequestContext(session_id=None, request=None) + + def _takes_context(self, handler: Callable) -> bool: + """Check if handler accepts context parameter.""" + try: + params = list(inspect.signature(handler).parameters.keys()) + return len(params) >= 2 and params[1] == "context" + except Exception: + return False + + async def _handle_jsonrpc(self, request): + """Handle JSON-RPC 2.0 requests at root endpoint.""" + request_context = self._build_request_context(request) + start_time = time.time() + + try: + body = await request.json() + self.logger.debug("Processing JSON-RPC request: %s", body.get("method", "unknown")) + + # Validate JSON-RPC format + if body.get("jsonrpc") != "2.0": + return self._jsonrpc_error_response( + body.get("id"), + JsonRpcErrorCode.INVALID_REQUEST, + "Invalid JSON-RPC version", + ) + + method = body.get("method") + if not method: + return self._jsonrpc_error_response( + body.get("id"), + JsonRpcErrorCode.INVALID_REQUEST, + "Missing method", + ) + + jsonrpc_request = JsonRpcRequest.from_dict(body) + + handler = self.handlers.get("main") + if not handler: + self.logger.error("No entrypoint defined") + return self._jsonrpc_error_response( + jsonrpc_request.id, + JsonRpcErrorCode.INTERNAL_ERROR, + "No entrypoint defined", + ) + + takes_context = self._takes_context(handler) + + self.logger.debug("Invoking handler for method: %s", method) + result = await self._invoke_handler(handler, request_context, takes_context, jsonrpc_request) + + duration = time.time() - start_time + + # Handle streaming responses + if inspect.isasyncgen(result): + self.logger.info("Returning streaming response (%.3fs)", duration) + return StreamingResponse( + self._stream_jsonrpc_response(result, jsonrpc_request.id), + media_type="text/event-stream", + ) + elif inspect.isgenerator(result): + self.logger.info("Returning streaming response (sync generator) (%.3fs)", duration) + return StreamingResponse( + self._sync_stream_jsonrpc_response(result, jsonrpc_request.id), + media_type="text/event-stream", + ) + + # Non-streaming response + self.logger.info("Request completed successfully (%.3fs)", duration) + response = JsonRpcResponse.success(jsonrpc_request.id, result) + return JSONResponse(response.to_dict()) + + except json.JSONDecodeError as e: + duration = time.time() - start_time + self.logger.warning("Invalid JSON in request (%.3fs): %s", duration, e) + return self._jsonrpc_error_response( + None, + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {str(e)}", + ) + except Exception as e: + duration = time.time() - start_time + self.logger.exception("Request failed (%.3fs)", duration) + return self._jsonrpc_error_response( + body.get("id") if "body" in dir() else None, + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + ) + + def _jsonrpc_error_response( + self, + request_id: Optional[str], + code: int, + message: str, + data: Optional[Any] = None, + ) -> JSONResponse: + """Create a JSON-RPC error response.""" + response = JsonRpcResponse.error_response(request_id, code, message, data) + return JSONResponse(response.to_dict()) + + async def _stream_jsonrpc_response(self, generator, request_id): + """Wrap async generator for SSE streaming with JSON-RPC format.""" + try: + async for value in generator: + # Wrap each chunk in JSON-RPC format + chunk_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": value, + } + yield self._to_sse(chunk_response) + except Exception as e: + self.logger.exception("Error in async streaming") + error_response = JsonRpcResponse.error_response( + request_id, + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + ) + yield self._to_sse(error_response.to_dict()) + + def _sync_stream_jsonrpc_response(self, generator, request_id): + """Wrap sync generator for SSE streaming with JSON-RPC format.""" + try: + for value in generator: + chunk_response = { + "jsonrpc": "2.0", + "id": request_id, + "result": value, + } + yield self._to_sse(chunk_response) + except Exception as e: + self.logger.exception("Error in sync streaming") + error_response = JsonRpcResponse.error_response( + request_id, + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + ) + yield self._to_sse(error_response.to_dict()) + + def _to_sse(self, data: Any) -> bytes: + """Convert data to SSE format.""" + json_string = json.dumps(data, ensure_ascii=False) + return f"data: {json_string}\n\n".encode("utf-8") + + def _handle_agent_card(self, request): + """Handle GET /.well-known/agent-card.json endpoint.""" + try: + runtime_url = self._get_runtime_url() + card_dict = self.agent_card.to_dict(url=runtime_url) + + self.logger.debug("Serving Agent Card: %s", self.agent_card.name) + return JSONResponse(card_dict) + except Exception as e: + self.logger.exception("Failed to serve Agent Card") + return JSONResponse({"error": str(e)}, status_code=500) + + def _handle_ping(self, request): + """Handle GET /ping health check endpoint.""" + try: + status = self.get_current_ping_status() + self.logger.debug("Ping request - status: %s", status.value) + return JSONResponse({"status": status.value, "time_of_last_update": int(self._last_status_update_time)}) + except Exception: + self.logger.exception("Ping endpoint failed") + return JSONResponse({"status": PingStatus.HEALTHY.value, "time_of_last_update": int(time.time())}) + + async def _invoke_handler(self, handler, request_context, takes_context, jsonrpc_request): + """Invoke the handler with appropriate arguments.""" + try: + args = (jsonrpc_request, request_context) if takes_context else (jsonrpc_request,) + + if asyncio.iscoroutinefunction(handler): + return await handler(*args) + else: + loop = asyncio.get_event_loop() + ctx = contextvars.copy_context() + return await loop.run_in_executor(None, ctx.run, handler, *args) + except Exception: + handler_name = getattr(handler, "__name__", "unknown") + self.logger.debug("Handler '%s' execution failed", handler_name) + raise + + def run(self, port: int = A2A_DEFAULT_PORT, host: Optional[str] = None, **kwargs): + """Start the Bedrock AgentCore A2A server. + + Args: + port: Port to serve on, defaults to 9000 (A2A standard) + host: Host to bind to, auto-detected if None + **kwargs: Additional arguments passed to uvicorn.run() + """ + import uvicorn + + if host is None: + if os.path.exists("/.dockerenv") or os.environ.get("DOCKER_CONTAINER"): + host = "0.0.0.0" # nosec B104 - Docker needs this to expose the port + else: + host = "127.0.0.1" + + uvicorn_params = { + "host": host, + "port": port, + "access_log": self.debug, + "log_level": "info" if self.debug else "warning", + } + uvicorn_params.update(kwargs) + + self.logger.info("Starting A2A server on %s:%d", host, port) + uvicorn.run(self, **uvicorn_params) diff --git a/src/bedrock_agentcore/runtime/a2a_models.py b/src/bedrock_agentcore/runtime/a2a_models.py new file mode 100644 index 0000000..0eebbac --- /dev/null +++ b/src/bedrock_agentcore/runtime/a2a_models.py @@ -0,0 +1,284 @@ +"""Models for Bedrock AgentCore A2A runtime. + +Contains data models for A2A protocol including Agent Card, JSON-RPC 2.0 messages, +and related types. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Union +from urllib.parse import quote + + +class JsonRpcErrorCode(int, Enum): + """Standard JSON-RPC 2.0 error codes and A2A-specific error codes.""" + + # Standard JSON-RPC 2.0 errors + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 + + # A2A-specific error codes (AgentCore Runtime) + RESOURCE_NOT_FOUND = -32501 + VALIDATION_ERROR = -32502 + THROTTLING = -32503 + RESOURCE_CONFLICT = -32504 + RUNTIME_CLIENT_ERROR = -32505 + + +@dataclass +class AgentSkill: + """A2A Agent Skill definition. + + Skills describe specific capabilities that the agent can perform. + """ + + id: str + name: str + description: str + tags: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "tags": self.tags, + } + + +@dataclass +class AgentCard: + """A2A Agent Card metadata. + + Agent Cards describe an agent's identity, capabilities, and how to communicate with it. + This metadata is served at /.well-known/agent-card.json endpoint. + """ + + name: str + description: str + version: str = "1.0.0" + protocol_version: str = "0.3.0" + preferred_transport: str = "JSONRPC" + capabilities: Dict[str, Any] = field(default_factory=lambda: {"streaming": True}) + default_input_modes: List[str] = field(default_factory=lambda: ["text"]) + default_output_modes: List[str] = field(default_factory=lambda: ["text"]) + skills: List[AgentSkill] = field(default_factory=list) + + def to_dict(self, url: Optional[str] = None) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization. + + Args: + url: The URL where this agent is accessible. If not provided, + the 'url' field will be omitted from the output. + + Returns: + Dictionary representation of the Agent Card. + """ + result = { + "name": self.name, + "description": self.description, + "version": self.version, + "protocolVersion": self.protocol_version, + "preferredTransport": self.preferred_transport, + "capabilities": self.capabilities, + "defaultInputModes": self.default_input_modes, + "defaultOutputModes": self.default_output_modes, + "skills": [skill.to_dict() for skill in self.skills], + } + if url: + result["url"] = url + return result + + +@dataclass +class JsonRpcRequest: + """JSON-RPC 2.0 Request object.""" + + method: str + id: Optional[Union[str, int]] = None + params: Optional[Dict[str, Any]] = None + jsonrpc: str = "2.0" + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JsonRpcRequest": + """Create from dictionary.""" + return cls( + jsonrpc=data.get("jsonrpc", "2.0"), + id=data.get("id"), + method=data.get("method", ""), + params=data.get("params"), + ) + + +@dataclass +class JsonRpcError: + """JSON-RPC 2.0 Error object.""" + + code: int + message: str + data: Optional[Any] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result = {"code": self.code, "message": self.message} + if self.data is not None: + result["data"] = self.data + return result + + +@dataclass +class JsonRpcResponse: + """JSON-RPC 2.0 Response object.""" + + id: Optional[Union[str, int]] + result: Optional[Any] = None + error: Optional[JsonRpcError] = None + jsonrpc: str = "2.0" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + response = {"jsonrpc": self.jsonrpc, "id": self.id} + if self.error is not None: + response["error"] = self.error.to_dict() + else: + response["result"] = self.result + return response + + @classmethod + def success(cls, id: Optional[Union[str, int]], result: Any) -> "JsonRpcResponse": + """Create a success response.""" + return cls(id=id, result=result) + + @classmethod + def error_response( + cls, + id: Optional[Union[str, int]], + code: int, + message: str, + data: Optional[Any] = None, + ) -> "JsonRpcResponse": + """Create an error response.""" + return cls(id=id, error=JsonRpcError(code=code, message=message, data=data)) + + +@dataclass +class A2AMessagePart: + """A2A message part (text, file, data, etc.).""" + + kind: str # "text", "file", "data", etc. + text: Optional[str] = None + file: Optional[Dict[str, Any]] = None + data: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: Dict[str, Any] = {"kind": self.kind} + if self.text is not None: + result["text"] = self.text + if self.file is not None: + result["file"] = self.file + if self.data is not None: + result["data"] = self.data + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "A2AMessagePart": + """Create from dictionary.""" + return cls( + kind=data.get("kind", "text"), + text=data.get("text"), + file=data.get("file"), + data=data.get("data"), + ) + + +@dataclass +class A2AMessage: + """A2A protocol message.""" + + role: str # "user", "agent" + parts: List[A2AMessagePart] + message_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: Dict[str, Any] = { + "role": self.role, + "parts": [part.to_dict() for part in self.parts], + } + if self.message_id: + result["messageId"] = self.message_id + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "A2AMessage": + """Create from dictionary.""" + parts = [A2AMessagePart.from_dict(p) for p in data.get("parts", [])] + return cls( + role=data.get("role", "user"), + parts=parts, + message_id=data.get("messageId"), + ) + + def get_text(self) -> str: + """Extract text content from message parts.""" + texts = [] + for part in self.parts: + if part.kind == "text" and part.text: + texts.append(part.text) + return "\n".join(texts) + + +@dataclass +class A2AArtifact: + """A2A protocol artifact (response content).""" + + artifact_id: str + name: str + parts: List[A2AMessagePart] + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "artifactId": self.artifact_id, + "name": self.name, + "parts": [part.to_dict() for part in self.parts], + } + + @classmethod + def from_text(cls, artifact_id: str, name: str, text: str) -> "A2AArtifact": + """Create a text artifact.""" + return cls( + artifact_id=artifact_id, + name=name, + parts=[A2AMessagePart(kind="text", text=text)], + ) + + +def build_runtime_url(agent_arn: str, region: str = "us-west-2") -> str: + """Build the AgentCore Runtime URL from an agent ARN. + + Args: + agent_arn: The ARN of the agent runtime + region: AWS region (default: us-west-2) + + Returns: + The full runtime URL with properly encoded ARN + """ + # URL encode the ARN (safe='' means encode all special characters) + escaped_arn = quote(agent_arn, safe="") + return f"https://bedrock-agentcore.{region}.amazonaws.com/runtimes/{escaped_arn}/invocations/" + + +# A2A Protocol Methods +A2A_METHOD_MESSAGE_SEND = "message/send" +A2A_METHOD_MESSAGE_STREAM = "message/stream" +A2A_METHOD_TASKS_GET = "tasks/get" +A2A_METHOD_TASKS_CANCEL = "tasks/cancel" + +# Default A2A port for AgentCore Runtime +A2A_DEFAULT_PORT = 9000 diff --git a/tests/bedrock_agentcore/runtime/test_a2a_app.py b/tests/bedrock_agentcore/runtime/test_a2a_app.py new file mode 100644 index 0000000..ef58cf7 --- /dev/null +++ b/tests/bedrock_agentcore/runtime/test_a2a_app.py @@ -0,0 +1,507 @@ +"""Tests for BedrockAgentCoreA2AApp.""" + +import asyncio +import contextlib +import json +import os +import uuid +from unittest.mock import patch + +import pytest +from starlette.testclient import TestClient + +from bedrock_agentcore.runtime import ( + AgentCard, + AgentSkill, + BedrockAgentCoreA2AApp, + JsonRpcRequest, +) + + +@pytest.fixture +def agent_card(): + """Create a test AgentCard.""" + return AgentCard( + name="Test Agent", + description="A test agent for unit testing", + skills=[ + AgentSkill(id="test", name="Test Skill", description="A test skill"), + ], + ) + + +@pytest.fixture +def app(agent_card): + """Create a test A2A app.""" + return BedrockAgentCoreA2AApp(agent_card=agent_card) + + +class TestBedrockAgentCoreA2AAppInitialization: + def test_basic_initialization(self, agent_card): + """Test basic app initialization.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + assert app.agent_card == agent_card + assert app.handlers == {} + assert app.debug is False + + def test_initialization_with_debug(self, agent_card): + """Test app initialization with debug mode.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card, debug=True) + assert app.debug is True + + def test_routes_registered(self, agent_card): + """Test that required routes are registered.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + route_paths = [route.path for route in app.routes] + assert "/" in route_paths + assert "/.well-known/agent-card.json" in route_paths + assert "/ping" in route_paths + + +class TestAgentCardEndpoint: + def test_agent_card_endpoint(self, app, agent_card): + """Test GET /.well-known/agent-card.json returns agent card.""" + client = TestClient(app) + response = client.get("/.well-known/agent-card.json") + + assert response.status_code == 200 + data = response.json() + assert data["name"] == agent_card.name + assert data["description"] == agent_card.description + assert data["protocolVersion"] == agent_card.protocol_version + assert len(data["skills"]) == 1 + assert data["skills"][0]["id"] == "test" + + def test_agent_card_with_runtime_url(self, agent_card): + """Test agent card includes URL when AGENTCORE_RUNTIME_URL is set.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + with patch.dict(os.environ, {"AGENTCORE_RUNTIME_URL": "https://example.com/agent"}): + client = TestClient(app) + response = client.get("/.well-known/agent-card.json") + + assert response.status_code == 200 + data = response.json() + assert data["url"] == "https://example.com/agent" + + +class TestPingEndpoint: + def test_ping_endpoint(self, app): + """Test GET /ping returns healthy status.""" + client = TestClient(app) + response = client.get("/ping") + + assert response.status_code == 200 + data = response.json() + assert data["status"] in ["Healthy", "HEALTHY"] + assert "time_of_last_update" in data + + def test_custom_ping_handler(self, agent_card): + """Test custom ping handler.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + + @app.ping + def custom_ping(): + return "HealthyBusy" + + client = TestClient(app) + response = client.get("/ping") + + assert response.status_code == 200 + data = response.json() + assert data["status"] in ["HealthyBusy", "HEALTHY_BUSY"] + + +class TestEntrypointDecorator: + def test_entrypoint_decorator(self, app): + """Test @app.entrypoint registers handler.""" + + @app.entrypoint + def handler(request, context): + return {"result": "success"} + + assert "main" in app.handlers + assert app.handlers["main"] == handler + assert hasattr(handler, "run") + + def test_entrypoint_without_context(self, app): + """Test entrypoint handler without context parameter.""" + + @app.entrypoint + def handler(request): + return {"result": request.method} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": {}, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert data["result"]["result"] == "message/send" + + +class TestJsonRpcHandling: + def test_valid_jsonrpc_request(self, app): + """Test valid JSON-RPC request.""" + + @app.entrypoint + def handler(request, context): + return {"artifacts": [{"artifactId": "art-1", "name": "response", "parts": []}]} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + } + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert "result" in data + assert "artifacts" in data["result"] + + def test_invalid_jsonrpc_version(self, app): + """Test invalid JSON-RPC version returns error.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "1.0", # Invalid version + "id": "req-001", + "method": "test", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32600 # Invalid request + + def test_missing_method(self, app): + """Test missing method returns error.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + # Missing method + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32600 # Invalid request + + def test_no_entrypoint_defined(self, app): + """Test error when no entrypoint is defined.""" + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 # Internal error + + def test_invalid_json(self, app): + """Test invalid JSON returns parse error.""" + + @app.entrypoint + def handler(request, context): + return {} + + client = TestClient(app) + response = client.post( + "/", + content="not valid json", + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32700 # Parse error + + def test_handler_exception(self, app): + """Test handler exception returns internal error.""" + + @app.entrypoint + def handler(request, context): + raise ValueError("Test error") + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == -32603 # Internal error + assert "Test error" in data["error"]["message"] + + +class TestAsyncHandler: + def test_async_handler(self, app): + """Test async handler.""" + + @app.entrypoint + async def handler(request, context): + await asyncio.sleep(0.01) + return {"result": "async success"} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["result"]["result"] == "async success" + + +class TestStreamingResponse: + def test_async_generator_response(self, app): + """Test async generator for streaming response.""" + + @app.entrypoint + async def handler(request, context): + async def generate(): + yield {"chunk": 1} + yield {"chunk": 2} + yield {"chunk": 3} + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/stream", + }, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + # Parse SSE events + events = response.text.split("\n\n") + events = [e for e in events if e.strip()] + + assert len(events) == 3 + for i, event in enumerate(events, 1): + assert event.startswith("data: ") + data = json.loads(event[6:]) + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert data["result"]["chunk"] == i + + def test_sync_generator_response(self, app): + """Test sync generator for streaming response.""" + + @app.entrypoint + def handler(request, context): + def generate(): + yield {"part": "A"} + yield {"part": "B"} + + return generate() + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + ) + + assert response.status_code == 200 + events = response.text.split("\n\n") + events = [e for e in events if e.strip()] + assert len(events) == 2 + + +class TestSessionHeader: + def test_session_id_from_header(self, app): + """Test session ID is extracted from header.""" + captured_session_id = None + + @app.entrypoint + def handler(request, context): + nonlocal captured_session_id + captured_session_id = context.session_id + return {"session": context.session_id} + + client = TestClient(app) + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + }, + headers={"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "test-session-123"}, + ) + + assert response.status_code == 200 + assert captured_session_id == "test-session-123" + + +class TestRunMethod: + @patch("uvicorn.run") + def test_run_default_port(self, mock_uvicorn, app): + """Test run uses default A2A port 9000.""" + app.run() + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["port"] == 9000 + assert call_kwargs["host"] == "127.0.0.1" + + @patch("uvicorn.run") + def test_run_custom_port(self, mock_uvicorn, app): + """Test run with custom port.""" + app.run(port=8080) + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["port"] == 8080 + + @patch.dict(os.environ, {"DOCKER_CONTAINER": "true"}) + @patch("uvicorn.run") + def test_run_in_docker(self, mock_uvicorn, agent_card): + """Test run in Docker environment.""" + app = BedrockAgentCoreA2AApp(agent_card=agent_card) + app.run() + + mock_uvicorn.assert_called_once() + call_kwargs = mock_uvicorn.call_args[1] + assert call_kwargs["host"] == "0.0.0.0" + + +class TestLifespan: + def test_lifespan_startup_and_shutdown(self, agent_card): + """Test lifespan startup and shutdown.""" + startup_called = False + shutdown_called = False + + @contextlib.asynccontextmanager + async def lifespan(app): + nonlocal startup_called, shutdown_called + startup_called = True + yield + shutdown_called = True + + app = BedrockAgentCoreA2AApp(agent_card=agent_card, lifespan=lifespan) + + with TestClient(app): + assert startup_called is True + assert shutdown_called is True + + +class TestIntegrationScenario: + def test_full_message_flow(self, app): + """Test complete message flow with A2A protocol.""" + + @app.entrypoint + def handler(request, context): + # Extract message from params + params = request.params or {} + message = params.get("message", {}) + parts = message.get("parts", []) + user_text = "" + for part in parts: + if part.get("kind") == "text": + user_text = part.get("text", "") + break + + # Return A2A formatted response + return { + "artifacts": [ + { + "artifactId": str(uuid.uuid4()), + "name": "agent_response", + "parts": [{"kind": "text", "text": f"Received: {user_text}"}], + } + ] + } + + client = TestClient(app) + + # Send message + response = client.post( + "/", + json={ + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "What is 2 + 2?"}], + "messageId": "msg-001", + } + }, + }, + ) + + assert response.status_code == 200 + data = response.json() + assert data["jsonrpc"] == "2.0" + assert data["id"] == "req-001" + assert "result" in data + assert "artifacts" in data["result"] + assert len(data["result"]["artifacts"]) == 1 + assert "Received: What is 2 + 2?" in data["result"]["artifacts"][0]["parts"][0]["text"] diff --git a/tests/bedrock_agentcore/runtime/test_a2a_models.py b/tests/bedrock_agentcore/runtime/test_a2a_models.py new file mode 100644 index 0000000..9d76f04 --- /dev/null +++ b/tests/bedrock_agentcore/runtime/test_a2a_models.py @@ -0,0 +1,340 @@ +"""Tests for A2A models.""" + +import pytest + +from bedrock_agentcore.runtime.a2a_models import ( + A2A_DEFAULT_PORT, + A2A_METHOD_MESSAGE_SEND, + A2AArtifact, + A2AMessage, + A2AMessagePart, + AgentCard, + AgentSkill, + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + build_runtime_url, +) + + +class TestAgentSkill: + def test_basic_creation(self): + """Test creating a basic AgentSkill.""" + skill = AgentSkill( + id="calc", + name="Calculator", + description="Perform arithmetic calculations", + ) + assert skill.id == "calc" + assert skill.name == "Calculator" + assert skill.description == "Perform arithmetic calculations" + assert skill.tags == [] + + def test_creation_with_tags(self): + """Test creating AgentSkill with tags.""" + skill = AgentSkill( + id="search", + name="Web Search", + description="Search the web", + tags=["search", "web", "information"], + ) + assert skill.tags == ["search", "web", "information"] + + def test_to_dict(self): + """Test AgentSkill serialization to dict.""" + skill = AgentSkill( + id="calc", + name="Calculator", + description="Math operations", + tags=["math"], + ) + result = skill.to_dict() + assert result == { + "id": "calc", + "name": "Calculator", + "description": "Math operations", + "tags": ["math"], + } + + +class TestAgentCard: + def test_basic_creation(self): + """Test creating a basic AgentCard.""" + card = AgentCard( + name="Test Agent", + description="A test agent", + ) + assert card.name == "Test Agent" + assert card.description == "A test agent" + assert card.version == "1.0.0" + assert card.protocol_version == "0.3.0" + assert card.preferred_transport == "JSONRPC" + assert card.capabilities == {"streaming": True} + assert card.default_input_modes == ["text"] + assert card.default_output_modes == ["text"] + assert card.skills == [] + + def test_creation_with_skills(self): + """Test AgentCard with skills.""" + skills = [ + AgentSkill(id="s1", name="Skill 1", description="First skill"), + AgentSkill(id="s2", name="Skill 2", description="Second skill"), + ] + card = AgentCard( + name="Multi-Skill Agent", + description="An agent with multiple skills", + skills=skills, + ) + assert len(card.skills) == 2 + assert card.skills[0].id == "s1" + assert card.skills[1].id == "s2" + + def test_to_dict_without_url(self): + """Test AgentCard serialization without URL.""" + card = AgentCard( + name="Test Agent", + description="A test agent", + ) + result = card.to_dict() + assert result["name"] == "Test Agent" + assert result["description"] == "A test agent" + assert result["protocolVersion"] == "0.3.0" + assert result["preferredTransport"] == "JSONRPC" + assert "url" not in result + + def test_to_dict_with_url(self): + """Test AgentCard serialization with URL.""" + card = AgentCard( + name="Test Agent", + description="A test agent", + ) + result = card.to_dict(url="https://example.com/agent") + assert result["url"] == "https://example.com/agent" + + def test_to_dict_with_skills(self): + """Test AgentCard serialization with skills.""" + skills = [AgentSkill(id="s1", name="Skill 1", description="First skill")] + card = AgentCard( + name="Test Agent", + description="A test agent", + skills=skills, + ) + result = card.to_dict() + assert len(result["skills"]) == 1 + assert result["skills"][0]["id"] == "s1" + + +class TestJsonRpcRequest: + def test_from_dict(self): + """Test creating JsonRpcRequest from dict.""" + data = { + "jsonrpc": "2.0", + "id": "req-001", + "method": "message/send", + "params": {"message": {"text": "Hello"}}, + } + request = JsonRpcRequest.from_dict(data) + assert request.jsonrpc == "2.0" + assert request.id == "req-001" + assert request.method == "message/send" + assert request.params == {"message": {"text": "Hello"}} + + def test_from_dict_minimal(self): + """Test creating JsonRpcRequest with minimal data.""" + data = {"method": "test"} + request = JsonRpcRequest.from_dict(data) + assert request.jsonrpc == "2.0" + assert request.id is None + assert request.method == "test" + assert request.params is None + + +class TestJsonRpcResponse: + def test_success_response(self): + """Test creating a success response.""" + response = JsonRpcResponse.success("req-001", {"result": "success"}) + assert response.id == "req-001" + assert response.result == {"result": "success"} + assert response.error is None + + def test_error_response(self): + """Test creating an error response.""" + response = JsonRpcResponse.error_response( + "req-001", + JsonRpcErrorCode.INTERNAL_ERROR, + "Something went wrong", + ) + assert response.id == "req-001" + assert response.result is None + assert response.error is not None + assert response.error.code == JsonRpcErrorCode.INTERNAL_ERROR + assert response.error.message == "Something went wrong" + + def test_success_to_dict(self): + """Test success response serialization.""" + response = JsonRpcResponse.success("req-001", {"data": "test"}) + result = response.to_dict() + assert result == { + "jsonrpc": "2.0", + "id": "req-001", + "result": {"data": "test"}, + } + + def test_error_to_dict(self): + """Test error response serialization.""" + response = JsonRpcResponse.error_response( + "req-001", + -32600, + "Invalid request", + ) + result = response.to_dict() + assert result == { + "jsonrpc": "2.0", + "id": "req-001", + "error": { + "code": -32600, + "message": "Invalid request", + }, + } + + +class TestA2AMessagePart: + def test_text_part(self): + """Test creating a text message part.""" + part = A2AMessagePart(kind="text", text="Hello, world!") + assert part.kind == "text" + assert part.text == "Hello, world!" + assert part.file is None + assert part.data is None + + def test_to_dict(self): + """Test message part serialization.""" + part = A2AMessagePart(kind="text", text="Test message") + result = part.to_dict() + assert result == {"kind": "text", "text": "Test message"} + + def test_from_dict(self): + """Test creating message part from dict.""" + data = {"kind": "text", "text": "Hello"} + part = A2AMessagePart.from_dict(data) + assert part.kind == "text" + assert part.text == "Hello" + + +class TestA2AMessage: + def test_basic_creation(self): + """Test creating a basic A2A message.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts) + assert message.role == "user" + assert len(message.parts) == 1 + assert message.message_id is None + + def test_with_message_id(self): + """Test message with ID.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts, message_id="msg-001") + assert message.message_id == "msg-001" + + def test_get_text(self): + """Test extracting text from message.""" + parts = [ + A2AMessagePart(kind="text", text="Line 1"), + A2AMessagePart(kind="text", text="Line 2"), + ] + message = A2AMessage(role="user", parts=parts) + assert message.get_text() == "Line 1\nLine 2" + + def test_to_dict(self): + """Test message serialization.""" + parts = [A2AMessagePart(kind="text", text="Hello")] + message = A2AMessage(role="user", parts=parts, message_id="msg-001") + result = message.to_dict() + assert result == { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "msg-001", + } + + def test_from_dict(self): + """Test creating message from dict.""" + data = { + "role": "agent", + "parts": [{"kind": "text", "text": "Response"}], + "messageId": "msg-002", + } + message = A2AMessage.from_dict(data) + assert message.role == "agent" + assert len(message.parts) == 1 + assert message.parts[0].text == "Response" + assert message.message_id == "msg-002" + + +class TestA2AArtifact: + def test_basic_creation(self): + """Test creating a basic artifact.""" + parts = [A2AMessagePart(kind="text", text="Result")] + artifact = A2AArtifact( + artifact_id="art-001", + name="response", + parts=parts, + ) + assert artifact.artifact_id == "art-001" + assert artifact.name == "response" + assert len(artifact.parts) == 1 + + def test_from_text(self): + """Test creating text artifact.""" + artifact = A2AArtifact.from_text("art-001", "response", "Hello") + assert artifact.artifact_id == "art-001" + assert artifact.name == "response" + assert len(artifact.parts) == 1 + assert artifact.parts[0].kind == "text" + assert artifact.parts[0].text == "Hello" + + def test_to_dict(self): + """Test artifact serialization.""" + artifact = A2AArtifact.from_text("art-001", "response", "Result") + result = artifact.to_dict() + assert result == { + "artifactId": "art-001", + "name": "response", + "parts": [{"kind": "text", "text": "Result"}], + } + + +class TestBuildRuntimeUrl: + def test_basic_url(self): + """Test building runtime URL.""" + arn = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-agent" + url = build_runtime_url(arn) + # ARN should be URL-encoded + assert "us-west-2" in url + assert "arn%3Aaws%3Abedrock-agentcore" in url + assert url.startswith("https://bedrock-agentcore.us-west-2.amazonaws.com/runtimes/") + assert url.endswith("/invocations/") + + def test_url_with_region(self): + """Test building URL with custom region.""" + arn = "arn:aws:bedrock-agentcore:us-east-1:123456789012:runtime/my-agent" + url = build_runtime_url(arn, region="us-east-1") + assert "us-east-1" in url + assert "bedrock-agentcore.us-east-1.amazonaws.com" in url + + def test_special_characters_encoded(self): + """Test that special characters in ARN are properly encoded.""" + arn = "arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/agent-with-special" + url = build_runtime_url(arn) + # Colon and slash should be encoded + assert "%3A" in url # Encoded colon + assert "%2F" in url # Encoded slash + + +class TestConstants: + def test_default_port(self): + """Test A2A default port.""" + assert A2A_DEFAULT_PORT == 9000 + + def test_method_constants(self): + """Test A2A method constants.""" + assert A2A_METHOD_MESSAGE_SEND == "message/send"