diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..b61c0242 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,191 @@ +version: 2 +updates: + # GitHub Actions + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly + + # Python dependencies — A2A agents + - package-ecosystem: pip + directory: /a2a/a2a_contact_extractor + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/a2a_currency_converter + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/cheerup_agent + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/file_organizer + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/generic_agent + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/git_issue_agent + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/image_service + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/recipe_agent + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/reservation_service + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/simple_generalist + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/slack_researcher + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/trivia_agent + schedule: + interval: weekly + - package-ecosystem: pip + directory: /a2a/weather_service + schedule: + interval: weekly + + # Python dependencies — MCP tools + - package-ecosystem: pip + directory: /mcp/cloud_storage_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/flight_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/image_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/movie_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/reservation_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/shopping_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/slack_tool + schedule: + interval: weekly + - package-ecosystem: pip + directory: /mcp/weather_tool + schedule: + interval: weekly + + # Docker — A2A agents + - package-ecosystem: docker + directory: /a2a/a2a_contact_extractor + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/a2a_currency_converter + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/cheerup_agent + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/file_organizer + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/generic_agent + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/git_issue_agent + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/image_service + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/recipe_agent + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/reservation_service + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/simple_generalist + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/slack_researcher + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/trivia_agent + schedule: + interval: weekly + - package-ecosystem: docker + directory: /a2a/weather_service + schedule: + interval: weekly + + # Docker — MCP tools + - package-ecosystem: docker + directory: /mcp/appworld_apis + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/cloud_storage_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/flight_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/github_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/image_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/movie_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/reservation_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/shopping_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/slack_tool + schedule: + interval: weekly + - package-ecosystem: docker + directory: /mcp/weather_tool + schedule: + interval: weekly diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 5ac0f812..4f6fe774 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -8,12 +8,12 @@ on: # Allows also to run this workflow manually from the Actions tab workflow_dispatch: +permissions: {} + jobs: build-and-push: - # The type of runner that the job will run on runs-on: ubuntu-latest - - # Grant GITHUB_TOKEN the permissions to write packages + timeout-minutes: 30 permissions: contents: read packages: write @@ -64,16 +64,16 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - name: Set up QEMU - uses: docker/setup-qemu-action@v4 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v4 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Log in to ghcr.io - uses: docker/login-action@v4 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4 with: registry: ghcr.io username: ${{ github.actor }} @@ -81,7 +81,7 @@ jobs: - name: Extract Docker metadata for ${{ matrix.image_config.name }} id: meta - uses: docker/metadata-action@v6 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6 with: images: ghcr.io/${{ github.repository }}/${{ matrix.image_config.name }} tags: | @@ -91,7 +91,7 @@ jobs: type=raw,value=latest,enable=${{ github.ref_type == 'tag' }} - name: Build and push ${{ matrix.image_config.name }} - uses: docker/build-push-action@v7 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7 with: context: ${{ matrix.image_config.path }}/${{ matrix.image_config.name }} file: ${{ matrix.image_config.path }}/${{ matrix.image_config.name }}/Dockerfile diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 00d45067..6be0e857 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,45 +1,39 @@ name: CI on: - # Triggers the workflow on push or pull request events but only for the "main" branch pull_request: - branches: [ "main" ] + branches: [main] + push: + branches: [main] + +permissions: + contents: read jobs: - build: + lint: runs-on: ubuntu-latest - - strategy: - matrix: - python-version: ["3.12"] - + timeout-minutes: 10 steps: - - name: Checkout repository - uses: actions/checkout@v6 - - # Sets up a specific version of Python - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 with: - python-version: ${{ matrix.python-version }} - - # Installs dependencies - # It's a good practice to cache dependencies to speed up subsequent runs - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - # Lints with flake8 - # This step checks for style issues in the code - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - # - name: Test with pytest - # run: | - # pytest + python-version: "3.12" + - name: Install ruff + run: pip install ruff==0.11.4 + - name: Lint with ruff + run: ruff check . + - name: Check formatting with ruff + run: ruff format --check . + + test: + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6 + with: + python-version: "3.12" + - name: Install test dependencies + run: pip install pytest pydantic pydantic-settings httpx python-dotenv + - name: Run tests + run: python -m pytest -v diff --git a/.github/workflows/scorecard.yaml b/.github/workflows/scorecard.yaml new file mode 100644 index 00000000..9e91983a --- /dev/null +++ b/.github/workflows/scorecard.yaml @@ -0,0 +1,35 @@ +name: Scorecard + +on: + push: + branches: [main] + schedule: + - cron: "30 6 * * 1" # Weekly Monday 6:30 AM UTC + workflow_dispatch: + +permissions: read-all + +jobs: + analysis: + runs-on: ubuntu-latest + timeout-minutes: 10 + permissions: + security-events: write + id-token: write + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + with: + persist-credentials: false + - uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3 + with: + results_file: results.sarif + results_format: sarif + publish_results: true + - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: scorecard-results + path: results.sarif + retention-days: 30 + - uses: github/codeql-action/upload-sarif@0d579ffd059c29b07949a3cce3983f0780820c98 # v4 + with: + sarif_file: results.sarif diff --git a/.github/workflows/security-scans.yaml b/.github/workflows/security-scans.yaml new file mode 100644 index 00000000..90900b81 --- /dev/null +++ b/.github/workflows/security-scans.yaml @@ -0,0 +1,68 @@ +name: Security Scans + +on: + pull_request: + branches: [main] + +permissions: {} + +jobs: + dependency-review: + runs-on: ubuntu-latest + timeout-minutes: 10 + permissions: + contents: read + pull-requests: write + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: actions/dependency-review-action@2031cfc080254a8a887f58cffee85186f0e49e48 # v4 + with: + fail-on-severity: critical + deny-licenses: GPL-3.0, AGPL-3.0 + + trivy-scan: + runs-on: ubuntu-latest + timeout-minutes: 15 + permissions: + contents: read + security-events: write + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.35.0 + with: + scan-type: fs + scan-ref: . + severity: CRITICAL,HIGH + exit-code: 0 # Informational — upstream dependency CVEs in community examples + format: sarif + output: trivy-results.sarif + - uses: github/codeql-action/upload-sarif@0d579ffd059c29b07949a3cce3983f0780820c98 # v4 + if: always() + with: + sarif_file: trivy-results.sarif + + codeql: + runs-on: ubuntu-latest + timeout-minutes: 15 + permissions: + security-events: write + contents: read + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: github/codeql-action/init@0d579ffd059c29b07949a3cce3983f0780820c98 # v4 + with: + languages: python + queries: security-extended + - uses: github/codeql-action/analyze@0d579ffd059c29b07949a3cce3983f0780820c98 # v4 + + hadolint: + runs-on: ubuntu-latest + timeout-minutes: 5 + permissions: + contents: read + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 + - uses: hadolint/hadolint-action@2332a7b74a6de0dda2e2221d575162eba76ba5e5 # v3.3.0 + with: + recursive: true + failure-threshold: error diff --git a/a2a/a2a_contact_extractor/__main__.py b/a2a/a2a_contact_extractor/__main__.py index c5a5ec75..fa26d49d 100644 --- a/a2a/a2a_contact_extractor/__main__.py +++ b/a2a/a2a_contact_extractor/__main__.py @@ -11,12 +11,11 @@ from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore from a2a.types import AgentCapabilities, AgentCard, AgentSkill -from starlette.routing import Route -from agent import ExtractorAgent from dotenv import load_dotenv from pydantic import BaseModel, EmailStr, Field +from starlette.routing import Route -from agent import ExtractorAgent # type: ignore[import-untyped] +from agent import ExtractorAgent from agent_executor import ExtractorAgentExecutor # type: ignore[import-untyped] load_dotenv() @@ -68,12 +67,15 @@ def main(host, port, result_type, instructions): app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host=host, port=port) diff --git a/a2a/a2a_contact_extractor/agent.py b/a2a/a2a_contact_extractor/agent.py index 0e590e7e..3db85625 100644 --- a/a2a/a2a_contact_extractor/agent.py +++ b/a2a/a2a_contact_extractor/agent.py @@ -2,19 +2,18 @@ import os import threading from collections.abc import AsyncIterable -from typing import Annotated, Any, ClassVar - -#from common.types import TextPart -from pydantic import BaseModel, Field -from typing import Annotated, Any, Literal +from typing import Annotated, Any, ClassVar, Literal import marvin +# from common.types import TextPart +from pydantic import BaseModel, Field + logger = logging.getLogger(__name__) class TextPart(BaseModel): - type: Literal['text'] = 'text' + type: Literal["text"] = "text" text: str metadata: dict[str, Any] | None = None diff --git a/a2a/a2a_contact_extractor/agent_executor.py b/a2a/a2a_contact_extractor/agent_executor.py index 897c9d53..3ae5e220 100644 --- a/a2a/a2a_contact_extractor/agent_executor.py +++ b/a2a/a2a_contact_extractor/agent_executor.py @@ -41,8 +41,8 @@ async def execute( async for item in self.agent.stream(query, task.contextId): is_task_complete = item["is_task_complete"] require_user_input = item["require_user_input"] - #content = item["content"] - content = item.get('content', '') + # content = item["content"] + content = item.get("content", "") logger.info( f"Stream item received: complete={is_task_complete}, require_input={require_user_input}, content_len={len(content)}" @@ -58,7 +58,6 @@ async def execute( # Extract the text from each TextPart object content = " ".join(part.text for part in content) - artifact = new_text_artifact( name="current_result", description="Result of request to agent.", diff --git a/a2a/a2a_currency_converter/Dockerfile b/a2a/a2a_currency_converter/Dockerfile index edb05147..b35fcc25 100644 --- a/a2a/a2a_currency_converter/Dockerfile +++ b/a2a/a2a_currency_converter/Dockerfile @@ -1,7 +1,7 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim ARG RELEASE_VERSION="main" -ADD . /app +COPY . /app WORKDIR /app RUN uv sync --no-cache --locked --link-mode copy diff --git a/a2a/a2a_currency_converter/app/__main__.py b/a2a/a2a_currency_converter/app/__main__.py index 33b2499b..9a3403fe 100644 --- a/a2a/a2a_currency_converter/app/__main__.py +++ b/a2a/a2a_currency_converter/app/__main__.py @@ -5,22 +5,20 @@ import click import httpx import uvicorn +from dotenv import load_dotenv +from starlette.routing import Route from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryPushNotifier, InMemoryTaskStore -from starlette.routing import Route from a2a.types import ( AgentCapabilities, AgentCard, AgentSkill, ) -from dotenv import load_dotenv - from app.agent import CurrencyAgent from app.agent_executor import CurrencyAgentExecutor - load_dotenv() logging.basicConfig(level=logging.INFO) @@ -32,8 +30,8 @@ class MissingAPIKeyError(Exception): @click.command() -@click.option('--host', 'host', default='localhost') -@click.option('--port', 'port', default=10000) +@click.option("--host", "host", default="localhost") +@click.option("--port", "port", default=10000) def main(host, port): """Starts the Currency Agent server.""" try: @@ -41,18 +39,18 @@ def main(host, port): capabilities = AgentCapabilities(streaming=True, pushNotifications=True) skill = AgentSkill( - id='convert_currency', - name='Currency Exchange Rates Tool', - description='Helps with exchange values between various currencies', - tags=['currency conversion', 'currency exchange'], - examples=['What is exchange rate between USD and GBP?'], + id="convert_currency", + name="Currency Exchange Rates Tool", + description="Helps with exchange values between various currencies", + tags=["currency conversion", "currency exchange"], + examples=["What is exchange rate between USD and GBP?"], ) agent_card = AgentCard( - name='Currency Agent', - description='Helps with exchange rates for currencies', + name="Currency Agent", + description="Helps with exchange rates for currencies", # Allow env var AGENT_ENDPOINT to override the URL in the agent card - url=os.getenv("AGENT_ENDPOINT", f'http://{host}:{port}/'), - version='1.0.0', + url=os.getenv("AGENT_ENDPOINT", f"http://{host}:{port}/"), + version="1.0.0", defaultInputModes=CurrencyAgent.SUPPORTED_CONTENT_TYPES, defaultOutputModes=CurrencyAgent.SUPPORTED_CONTENT_TYPES, capabilities=capabilities, @@ -66,30 +64,31 @@ def main(host, port): task_store=InMemoryTaskStore(), push_notifier=InMemoryPushNotifier(httpx_client), ) - server = A2AStarletteApplication( - agent_card=agent_card, http_handler=request_handler - ) + server = A2AStarletteApplication(agent_card=agent_card, http_handler=request_handler) app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host=host, port=port) # --8<-- [end:DefaultRequestHandler] except MissingAPIKeyError as e: - logger.error(f'Error: {e}') + logger.error(f"Error: {e}") sys.exit(1) except Exception as e: - logger.error(f'An error occurred during server startup: {e}') + logger.error(f"An error occurred during server startup: {e}") sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/a2a/a2a_currency_converter/app/agent.py b/a2a/a2a_currency_converter/app/agent.py index ba6c761c..f859d06e 100644 --- a/a2a/a2a_currency_converter/app/agent.py +++ b/a2a/a2a_currency_converter/app/agent.py @@ -1,24 +1,21 @@ """Currency conversion logic for A2A example""" import os - from collections.abc import AsyncIterable from typing import Any, Literal import httpx - from langchain_core.messages import AIMessage, ToolMessage from langchain_core.tools import tool -from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_openai import ChatOpenAI from langgraph.checkpoint.memory import MemorySaver from langgraph.prebuilt import create_react_agent from pydantic import BaseModel from pydantic_settings import BaseSettings -from langchain_openai import ChatOpenAI - memory = MemorySaver() + class Configuration(BaseSettings): """The configuration of the Agent""" @@ -34,9 +31,9 @@ class Configuration(BaseSettings): @tool def get_exchange_rate( - currency_from: str = 'USD', - currency_to: str = 'EUR', - currency_date: str = 'latest', + currency_from: str = "USD", + currency_to: str = "EUR", + currency_date: str = "latest", ): """Use this to get current exchange rate. @@ -52,25 +49,25 @@ def get_exchange_rate( """ try: response = httpx.get( - f'https://api.frankfurter.app/{currency_date}', - params={'from': currency_from, 'to': currency_to}, + f"https://api.frankfurter.app/{currency_date}", + params={"from": currency_from, "to": currency_to}, ) response.raise_for_status() data = response.json() - if 'rates' not in data: - return {'error': 'Invalid API response format.'} + if "rates" not in data: + return {"error": "Invalid API response format."} return data except httpx.HTTPError as e: - return {'error': f'API request failed: {e}'} + return {"error": f"API request failed: {e}"} except ValueError: - return {'error': 'Invalid JSON response from API.'} + return {"error": "Invalid JSON response from API."} class ResponseFormat(BaseModel): """Respond to the user in this format.""" - status: Literal['input_required', 'completed', 'error'] = 'input_required' + status: Literal["input_required", "completed", "error"] = "input_required" message: str @@ -78,18 +75,18 @@ class CurrencyAgent: """CurrencyAgent - a specialized assistant for currency conversions.""" SYSTEM_INSTRUCTION = ( - 'You are a specialized assistant for currency conversions. ' + "You are a specialized assistant for currency conversions. " "Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates. " - 'If the user asks about anything other than currency conversion or exchange rates, ' - 'politely state that you cannot help with that topic and can only assist with currency-related queries. ' - 'Do not attempt to answer unrelated questions or use tools for other purposes.' - 'Set response status to input_required if the user needs to provide more information.' - 'Set response status to error if there is an error while processing the request.' - 'Set response status to completed if the request is complete.' + "If the user asks about anything other than currency conversion or exchange rates, " + "politely state that you cannot help with that topic and can only assist with currency-related queries. " + "Do not attempt to answer unrelated questions or use tools for other purposes." + "Set response status to input_required if the user needs to provide more information." + "Set response status to error if there is an error while processing the request." + "Set response status to completed if the request is complete." ) def __init__(self): - # self.model = ChatGoogleGenerativeAI(model='gemini-2.0-flash') + # self.model = ChatGoogleGenerativeAI(model='gemini-2.0-flash') self.model = ChatOpenAI( model=config.llm_model, openai_api_key=config.llm_api_key, @@ -107,67 +104,58 @@ def __init__(self): ) def invoke(self, query, context_id) -> str: - config = {'configurable': {'thread_id': context_id}} - self.graph.invoke({'messages': [('user', query)]}, config) + config = {"configurable": {"thread_id": context_id}} + self.graph.invoke({"messages": [("user", query)]}, config) return self.get_agent_response(config) async def stream(self, query, context_id) -> AsyncIterable[dict[str, Any]]: - inputs = {'messages': [('user', query)]} - config = {'configurable': {'thread_id': context_id}} - - for item in self.graph.stream(inputs, config, stream_mode='values'): - message = item['messages'][-1] - if ( - isinstance(message, AIMessage) - and message.tool_calls - and len(message.tool_calls) > 0 - ): + inputs = {"messages": [("user", query)]} + config = {"configurable": {"thread_id": context_id}} + + for item in self.graph.stream(inputs, config, stream_mode="values"): + message = item["messages"][-1] + if isinstance(message, AIMessage) and message.tool_calls and len(message.tool_calls) > 0: yield { - 'is_task_complete': False, - 'require_user_input': False, - 'content': 'Looking up the exchange rates...', + "is_task_complete": False, + "require_user_input": False, + "content": "Looking up the exchange rates...", } elif isinstance(message, ToolMessage): yield { - 'is_task_complete': False, - 'require_user_input': False, - 'content': 'Processing the exchange rates..', + "is_task_complete": False, + "require_user_input": False, + "content": "Processing the exchange rates..", } yield self.get_agent_response(config) def get_agent_response(self, config): current_state = self.graph.get_state(config) - structured_response = current_state.values.get('structured_response') - if structured_response and isinstance( - structured_response, ResponseFormat - ): - if structured_response.status == 'input_required': + structured_response = current_state.values.get("structured_response") + if structured_response and isinstance(structured_response, ResponseFormat): + if structured_response.status == "input_required": return { - 'is_task_complete': False, - 'require_user_input': True, - 'content': structured_response.message, + "is_task_complete": False, + "require_user_input": True, + "content": structured_response.message, } - if structured_response.status == 'error': + if structured_response.status == "error": return { - 'is_task_complete': False, - 'require_user_input': True, - 'content': structured_response.message, + "is_task_complete": False, + "require_user_input": True, + "content": structured_response.message, } - if structured_response.status == 'completed': + if structured_response.status == "completed": return { - 'is_task_complete': True, - 'require_user_input': False, - 'content': structured_response.message, + "is_task_complete": True, + "require_user_input": False, + "content": structured_response.message, } return { - 'is_task_complete': False, - 'require_user_input': True, - 'content': ( - 'We are unable to process your request at the moment. ' - 'Please try again.' - ), + "is_task_complete": False, + "require_user_input": True, + "content": ("We are unable to process your request at the moment. Please try again."), } - SUPPORTED_CONTENT_TYPES = ['text', 'text/plain'] + SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] diff --git a/a2a/a2a_currency_converter/app/agent_executor.py b/a2a/a2a_currency_converter/app/agent_executor.py index 17ea43ce..999477b8 100644 --- a/a2a/a2a_currency_converter/app/agent_executor.py +++ b/a2a/a2a_currency_converter/app/agent_executor.py @@ -5,8 +5,6 @@ import logging import os -from app.agent import CurrencyAgent - from openai import AuthenticationError, InternalServerError from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -26,6 +24,7 @@ new_task, ) from a2a.utils.errors import ServerError +from app.agent import CurrencyAgent logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -44,37 +43,37 @@ async def execute( ) -> None: error = self._validate_request(context) if error: - logger.warning(f'Invalid agent executor request: {context}') + logger.warning(f"Invalid agent executor request: {context}") raise ServerError(error=InvalidParamsError()) query = context.get_user_input() task = context.current_task if not task: task = new_task(context.message) - logger.info(f'Created task for message : {context.message}') + logger.info(f"Created task for message : {context.message}") event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.contextId) try: async for item in self.agent.stream(query, task.contextId): - is_task_complete = item['is_task_complete'] - require_user_input = item['require_user_input'] + is_task_complete = item["is_task_complete"] + require_user_input = item["require_user_input"] if not is_task_complete and not require_user_input: - logger.info(f'Updating status for non-input task: {task.id}') + logger.info(f"Updating status for non-input task: {task.id}") updater.update_status( TaskState.working, new_agent_text_message( - item['content'], + item["content"], task.contextId, task.id, ), ) elif require_user_input: - logger.info(f'Updating status for input task: {task.id}') + logger.info(f"Updating status for input task: {task.id}") updater.update_status( TaskState.input_required, new_agent_text_message( - item['content'], + item["content"], task.contextId, task.id, ), @@ -82,16 +81,16 @@ async def execute( ) break else: - logger.info('Adding artifact for item') + logger.info("Adding artifact for item") updater.add_artifact( - [Part(root=TextPart(text=item['content']))], - name='conversion_result', + [Part(root=TextPart(text=item["content"]))], + name="conversion_result", ) updater.complete() break except InternalServerError as e: - msg=f"""CurrencyAgentExecutor reports an InternalServerError error. + msg = f"""CurrencyAgentExecutor reports an InternalServerError error. This can happen if the agent's LLM_API_BASE environment variable does not point to an OpenAI server. @@ -112,7 +111,7 @@ async def execute( ) except AuthenticationError as e: - msg=f"""CurrencyAgentExecutor reports an authentication error. + msg = f"""CurrencyAgentExecutor reports an authentication error. When importing this agent into Kagenti, expand Environment Variables and Add Variable, or import https://github.com/kagenti/agent-examples/blob/main/a2a/a2a_currency_converter/.env.openai @@ -120,7 +119,7 @@ async def execute( Use `kubectl -n logs deployment/` for details. Also check -`kubectl -n get secret openai-secret -o jsonpath="{'{'}.data.apikey{'}'}" | base64 -d` +`kubectl -n get secret openai-secret -o jsonpath="{"{"}.data.apikey{"}"}" | base64 -d` The key should match your OpenAI key.""" logger.error(msg=msg) logger.error(msg=f"Raw AuthenticationError {e}") @@ -135,8 +134,8 @@ async def execute( ) except Exception as e: - logger.error(f'An error occurred while streaming the response: {e}') - logger.info(msg=f'The error is a {type(e)}') + logger.error(f"An error occurred while streaming the response: {e}") + logger.info(msg=f"The error is a {type(e)}") updater.update_status( TaskState.input_required, new_agent_text_message( @@ -153,7 +152,5 @@ async def execute( def _validate_request(self, _: RequestContext) -> bool: return False - async def cancel( - self, _: RequestContext, event_queue: EventQueue - ) -> Task | None: + async def cancel(self, _: RequestContext, event_queue: EventQueue) -> Task | None: raise ServerError(error=UnsupportedOperationError()) diff --git a/a2a/a2a_currency_converter/app/test_client.py b/a2a/a2a_currency_converter/app/test_client.py index 1840662f..d9970630 100644 --- a/a2a/a2a_currency_converter/app/test_client.py +++ b/a2a/a2a_currency_converter/app/test_client.py @@ -1,5 +1,4 @@ import logging - from typing import Any from uuid import uuid4 @@ -9,15 +8,13 @@ from a2a.types import ( AgentCard, MessageSendParams, - SendMessageRequest, SendStreamingMessageRequest, ) - -PUBLIC_AGENT_CARD_PATH = '/.well-known/agent.json' -EXTENDED_AGENT_CARD_PATH = '/agent/authenticatedExtendedCard' -#USER_INPUT = 'how much is 10 USD in INR?' -USER_INPUT = 'My name is John Doe, email: john@example.com, phone: (555) 123-4567' +PUBLIC_AGENT_CARD_PATH = "/.well-known/agent.json" +EXTENDED_AGENT_CARD_PATH = "/agent/authenticatedExtendedCard" +# USER_INPUT = 'how much is 10 USD in INR?' +USER_INPUT = "My name is John Doe, email: john@example.com, phone: (555) 123-4567" async def main() -> None: @@ -27,7 +24,7 @@ async def main() -> None: # --8<-- [start:A2ACardResolver] - #base_url = 'http://localhost:9000' + # base_url = 'http://localhost:9000' base_url = "http://a2a-contact-extractor.localtest.me:8080" async with httpx.AsyncClient() as httpx_client: @@ -40,88 +37,57 @@ async def main() -> None: # --8<-- [end:A2ACardResolver] # Fetch Public Agent Card and Initialize Client - final_agent_card_to_use: AgentCard | None = None + _final_agent_card_to_use: AgentCard | None = None try: - logger.info( - f'Attempting to fetch public agent card from: {base_url}{PUBLIC_AGENT_CARD_PATH}' - ) - _public_card = ( - await resolver.get_agent_card() - ) # Fetches from default public path - logger.info('Successfully fetched public agent card:') - logger.info( - _public_card.model_dump_json(indent=2, exclude_none=True) - ) - final_agent_card_to_use = _public_card - logger.info( - '\nUsing PUBLIC agent card for client initialization (default).' - ) + logger.info(f"Attempting to fetch public agent card from: {base_url}{PUBLIC_AGENT_CARD_PATH}") + _public_card = await resolver.get_agent_card() # Fetches from default public path + logger.info("Successfully fetched public agent card:") + logger.info(_public_card.model_dump_json(indent=2, exclude_none=True)) + _final_agent_card_to_use = _public_card + logger.info("\nUsing PUBLIC agent card for client initialization (default).") if _public_card.supportsAuthenticatedExtendedCard: try: logger.info( - '\nPublic card supports authenticated extended card. ' - 'Attempting to fetch from: ' - f'{base_url}{EXTENDED_AGENT_CARD_PATH}' + "\nPublic card supports authenticated extended card. " + "Attempting to fetch from: " + f"{base_url}{EXTENDED_AGENT_CARD_PATH}" ) - auth_headers_dict = { - 'Authorization': 'Bearer dummy-token-for-extended-card' - } + auth_headers_dict = {"Authorization": "Bearer dummy-token-for-extended-card"} _extended_card = await resolver.get_agent_card( relative_card_path=EXTENDED_AGENT_CARD_PATH, - http_kwargs={'headers': auth_headers_dict}, - ) - logger.info( - 'Successfully fetched authenticated extended agent card:' - ) - logger.info( - _extended_card.model_dump_json( - indent=2, exclude_none=True - ) - ) - final_agent_card_to_use = ( - _extended_card # Update to use the extended card - ) - logger.info( - '\nUsing AUTHENTICATED EXTENDED agent card for client ' - 'initialization.' + http_kwargs={"headers": auth_headers_dict}, ) + logger.info("Successfully fetched authenticated extended agent card:") + logger.info(_extended_card.model_dump_json(indent=2, exclude_none=True)) + _final_agent_card_to_use = _extended_card # Update to use the extended card + logger.info("\nUsing AUTHENTICATED EXTENDED agent card for client initialization.") except Exception as e_extended: logger.warning( - f'Failed to fetch extended agent card: {e_extended}. ' - 'Will proceed with public card.', + f"Failed to fetch extended agent card: {e_extended}. Will proceed with public card.", exc_info=True, ) - elif ( - _public_card - ): # supportsAuthenticatedExtendedCard is False or None - logger.info( - '\nPublic card does not indicate support for an extended card. Using public card.' - ) + elif _public_card: # supportsAuthenticatedExtendedCard is False or None + logger.info("\nPublic card does not indicate support for an extended card. Using public card.") except Exception as e: - logger.error( - f'Critical error fetching public agent card: {e}', exc_info=True - ) - raise RuntimeError( - 'Failed to fetch the public agent card. Cannot continue.' - ) from e + logger.error(f"Critical error fetching public agent card: {e}", exc_info=True) + raise RuntimeError("Failed to fetch the public agent card. Cannot continue.") from e # --8<-- [start:send_message] client = A2AClient( - #httpx_client=httpx_client, agent_card=final_agent_card_to_use - httpx_client=httpx_client,url=base_url + # httpx_client=httpx_client, agent_card=final_agent_card_to_use + httpx_client=httpx_client, + url=base_url, ) - logger.info('A2AClient initialized.') + logger.info("A2AClient initialized.") send_message_payload: dict[str, Any] = { - 'message': { - 'role': 'user', - 'parts': [ - {'kind': 'text', 'text': USER_INPUT} - ], - 'messageId': uuid4().hex, + "message": { + "role": "user", + "parts": [{"kind": "text", "text": USER_INPUT}], + "messageId": uuid4().hex, }, } # request = SendMessageRequest( @@ -141,11 +107,11 @@ async def main() -> None: stream_response = client.send_message_streaming(streaming_request) async for chunk in stream_response: - print(chunk.model_dump(mode='json', exclude_none=True)) + print(chunk.model_dump(mode="json", exclude_none=True)) # --8<-- [end:send_message_streaming] -if __name__ == '__main__': +if __name__ == "__main__": import asyncio asyncio.run(main()) diff --git a/a2a/cheerup_agent/src/cheerup_agent/__init__.py b/a2a/cheerup_agent/src/cheerup_agent/__init__.py index 5318a82e..bf78fec7 100644 --- a/a2a/cheerup_agent/src/cheerup_agent/__init__.py +++ b/a2a/cheerup_agent/src/cheerup_agent/__init__.py @@ -1,16 +1,20 @@ -from opentelemetry.sdk.resources import Resource from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + def setup_tracer(): - resource = Resource.create(attributes={ - "service.name": "cheerup-agent", - }) + resource = Resource.create( + attributes={ + "service.name": "cheerup-agent", + } + ) provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(OTLPSpanExporter()) provider.add_span_processor(processor) trace.set_tracer_provider(provider) + setup_tracer() diff --git a/a2a/cheerup_agent/src/cheerup_agent/agent.py b/a2a/cheerup_agent/src/cheerup_agent/agent.py index 8ddf8cd8..6aac7db2 100644 --- a/a2a/cheerup_agent/src/cheerup_agent/agent.py +++ b/a2a/cheerup_agent/src/cheerup_agent/agent.py @@ -1,10 +1,10 @@ import logging +from textwrap import dedent import uvicorn from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route -from textwrap import dedent from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication @@ -13,7 +13,6 @@ from a2a.server.tasks import InMemoryTaskStore, TaskUpdater from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task - from cheerup_agent.cheerup_llm import chat logging.basicConfig(level=logging.DEBUG) diff --git a/a2a/cheerup_agent/src/cheerup_agent/configuration.py b/a2a/cheerup_agent/src/cheerup_agent/configuration.py index d837750a..5b009f42 100644 --- a/a2a/cheerup_agent/src/cheerup_agent/configuration.py +++ b/a2a/cheerup_agent/src/cheerup_agent/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "qwen3:4b" llm_api_base: str = "http://host.docker.internal:11434/v1" diff --git a/a2a/file_organizer/src/file_organizer/__init__.py b/a2a/file_organizer/src/file_organizer/__init__.py index e48584ee..40a1f70d 100644 --- a/a2a/file_organizer/src/file_organizer/__init__.py +++ b/a2a/file_organizer/src/file_organizer/__init__.py @@ -1,16 +1,20 @@ -from opentelemetry.sdk.resources import Resource from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + def setup_tracer(): - resource = Resource.create(attributes={ - "service.name": "a2a-server", - }) + resource = Resource.create( + attributes={ + "service.name": "a2a-server", + } + ) provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(OTLPSpanExporter()) provider.add_span_processor(processor) trace.set_tracer_provider(provider) + setup_tracer() diff --git a/a2a/file_organizer/src/file_organizer/agent.py b/a2a/file_organizer/src/file_organizer/agent.py index e4878a3d..2bc3a487 100644 --- a/a2a/file_organizer/src/file_organizer/agent.py +++ b/a2a/file_organizer/src/file_organizer/agent.py @@ -1,20 +1,19 @@ -import json import logging import os -import uvicorn from textwrap import dedent +import uvicorn +from langchain_core.messages import HumanMessage +from openinference.instrumentation.langchain import LangChainInstrumentor +from starlette.routing import Route + from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater -from starlette.routing import Route from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task -from openinference.instrumentation.langchain import LangChainInstrumentor -from langchain_core.messages import HumanMessage - from file_organizer.graph import get_graph, get_mcpclient logging.basicConfig(level=logging.DEBUG) @@ -22,6 +21,7 @@ LangChainInstrumentor().instrument() + def get_agent_card(host: str, port: int): """Returns the Agent Card for the A2A Agent.""" capabilities = AgentCapabilities(streaming=True) @@ -56,6 +56,7 @@ def get_agent_card(host: str, port: int): skills=[skill], ) + class A2AEvent: """ A class to handle events for A2A Agent. @@ -87,10 +88,12 @@ async def emit_event(self, message: str, final: bool = False, failed: bool = Fal ), ) + class FileOrganizerExecutor(AgentExecutor): """ A class to handle file organizer execution for A2A Agent. """ + async def execute(self, context: RequestContext, event_queue: EventQueue): """ The agent allows to organize files through a natural language conversational interface @@ -108,22 +111,25 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() messages = [HumanMessage(content=user_input)] input_data = {"messages": messages} - logger.info(f'Processing messages: {input_data}') + logger.info(f"Processing messages: {input_data}") try: # Test MCP connection first - logger.info(f'Attempting to connect to MCP server at: {os.getenv("MCP_URL", "http://localhost:8000/sse")}') + logger.info(f"Attempting to connect to MCP server at: {os.getenv('MCP_URL', 'http://localhost:8000/sse')}") mcpclient = get_mcpclient() # Try to get tools to verify connection try: tools = await mcpclient.get_tools() - logger.info(f'Successfully connected to MCP server. Available tools: {[tool.name for tool in tools]}') + logger.info(f"Successfully connected to MCP server. Available tools: {[tool.name for tool in tools]}") except Exception as tool_error: - logger.error(f'Failed to connect to MCP server: {tool_error}') - await event_emitter.emit_event(f"Error: Cannot connect to MCP cloud storage at {os.getenv('MCP_URL', 'http://localhost:8000/sse')}. Please ensure the cloud storage MCP server is running. Error: {tool_error}", failed=True) + logger.error(f"Failed to connect to MCP server: {tool_error}") + await event_emitter.emit_event( + f"Error: Cannot connect to MCP cloud storage at {os.getenv('MCP_URL', 'http://localhost:8000/sse')}. Please ensure the cloud storage MCP server is running. Error: {tool_error}", + failed=True, + ) return graph = await get_graph(mcpclient) @@ -137,16 +143,19 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): + "\n" ) output = event - logger.info(f'event: {event}') - + logger.info(f"event: {event}") + if output: final_answer = output.get("assistant", {}).get("final_answer", "File organization completed.") await event_emitter.emit_event(str(final_answer), final=True) else: await event_emitter.emit_event("File organization completed.", final=True) except Exception as e: - logger.error(f'Graph execution error: {e}') - await event_emitter.emit_event(f"Error: Failed to process file organization request. {str(e)}", failed=True) + logger.error(f"Graph execution error: {e}") + await event_emitter.emit_event( + f"Error: Failed to process file organization request. {str(e)}", + failed=True, + ) raise Exception(str(e)) async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: @@ -155,6 +164,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """ raise Exception("cancel not supported") + def run(): """ Runs the A2A Agent application. @@ -174,11 +184,14 @@ def run(): app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/a2a/file_organizer/src/file_organizer/configuration.py b/a2a/file_organizer/src/file_organizer/configuration.py index 0c7dcb66..dd1c9f4c 100644 --- a/a2a/file_organizer/src/file_organizer/configuration.py +++ b/a2a/file_organizer/src/file_organizer/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "llama3.1" llm_api_base: str = "http://localhost:11434/v1" diff --git a/a2a/file_organizer/src/file_organizer/graph.py b/a2a/file_organizer/src/file_organizer/graph.py index d0f2a2da..f16c5d77 100644 --- a/a2a/file_organizer/src/file_organizer/graph.py +++ b/a2a/file_organizer/src/file_organizer/graph.py @@ -1,26 +1,31 @@ import os -from langgraph.graph import StateGraph, MessagesState, START + +from langchain_core.messages import AIMessage, SystemMessage from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_core.messages import SystemMessage, AIMessage -from langgraph.prebuilt import tools_condition, ToolNode from langchain_openai import ChatOpenAI +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition from file_organizer.configuration import Configuration config = Configuration() + # Extend MessagesState to include a final answer class ExtendedMessagesState(MessagesState): - final_answer: str = "" + final_answer: str = "" + def get_mcpclient(): - - return MultiServerMCPClient({ - "cloud_storage": { - "url": os.getenv("MCP_URL", "http://cloud-storage-tool:8000/mcp"), - "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + return MultiServerMCPClient( + { + "cloud_storage": { + "url": os.getenv("MCP_URL", "http://cloud-storage-tool:8000/mcp"), + "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + } } - }) + ) + async def get_graph(client) -> StateGraph: llm = ChatOpenAI( @@ -35,9 +40,14 @@ async def get_graph(client) -> StateGraph: llm_with_tools = llm.bind_tools(tools) bucket_uri = os.getenv("BUCKET_URI") - bucket_info = f"Target bucket: {bucket_uri}" if bucket_uri else "No bucket URI configured. Ask the user to specify which bucket to organize." + bucket_info = ( + f"Target bucket: {bucket_uri}" + if bucket_uri + else "No bucket URI configured. Ask the user to specify which bucket to organize." + ) - sys_msg = SystemMessage(content=f"""You are a file organization assistant for cloud storage buckets. + sys_msg = SystemMessage( + content=f"""You are a file organization assistant for cloud storage buckets. {bucket_info} @@ -50,7 +60,8 @@ async def get_graph(client) -> StateGraph: - Logical grouping (similar file types together) 4. Use the perform_action tool to move the object as needed 5. Provide a summary of what you did -""") +""" + ) # Node def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: @@ -76,4 +87,4 @@ def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: # Compile graph graph = builder.compile() - return graph \ No newline at end of file + return graph diff --git a/a2a/generic_agent/src/generic_agent/__init__.py b/a2a/generic_agent/src/generic_agent/__init__.py index e48584ee..40a1f70d 100644 --- a/a2a/generic_agent/src/generic_agent/__init__.py +++ b/a2a/generic_agent/src/generic_agent/__init__.py @@ -1,16 +1,20 @@ -from opentelemetry.sdk.resources import Resource from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + def setup_tracer(): - resource = Resource.create(attributes={ - "service.name": "a2a-server", - }) + resource = Resource.create( + attributes={ + "service.name": "a2a-server", + } + ) provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(OTLPSpanExporter()) provider.add_span_processor(processor) trace.set_tracer_provider(provider) + setup_tracer() diff --git a/a2a/generic_agent/src/generic_agent/agent.py b/a2a/generic_agent/src/generic_agent/agent.py index 9838bb53..dc61343c 100644 --- a/a2a/generic_agent/src/generic_agent/agent.py +++ b/a2a/generic_agent/src/generic_agent/agent.py @@ -1,20 +1,20 @@ import logging -import uvicorn from textwrap import dedent +import uvicorn +from langchain_core.messages import HumanMessage +from openinference.instrumentation.langchain import LangChainInstrumentor +from starlette.routing import Route + from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater -from starlette.routing import Route from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task -from openinference.instrumentation.langchain import LangChainInstrumentor -from langchain_core.messages import HumanMessage - -from generic_agent.graph import get_graph, get_mcpclient, get_mcp_server_names from generic_agent.config import Configuration +from generic_agent.graph import get_graph, get_mcp_server_names, get_mcpclient logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -22,6 +22,7 @@ LangChainInstrumentor().instrument() config = Configuration() + def get_agent_card(host: str, port: int) -> AgentCard: """Returns the Agent Card for the A2A Agent.""" try: @@ -32,8 +33,7 @@ def get_agent_card(host: str, port: int) -> AgentCard: mcp_section = "" if mcp_names: mcp_section = "\n\nConnected MCP Servers:\n" + "\n".join(f"- {name}" for name in mcp_names) - - + capabilities = AgentCapabilities(streaming=True) skill = AgentSkill( id="generic_agent", @@ -57,6 +57,7 @@ def get_agent_card(host: str, port: int) -> AgentCard: skills=[skill], ) + class A2AEvent: """ A class to handle events for A2A Agent. @@ -71,12 +72,12 @@ def __init__(self, task_updater: TaskUpdater): async def emit_event(self, message: str, final: bool = False, failed: bool = False) -> None: """ Emit an event to update task status. - + Args: message: The message content to emit final: If True, marks the task as complete failed: If True, marks the task as failed - + Raises: Exception: If event emission fails """ @@ -99,10 +100,12 @@ async def emit_event(self, message: str, final: bool = False, failed: bool = Fal ), ) + class GenericExecutor(AgentExecutor): """ A class to handle generic assistant execution for A2A Agent. """ + async def execute(self, context: RequestContext, event_queue: EventQueue): """ The agent completes tasks through a natural language conversational interface @@ -115,7 +118,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): await event_queue.enqueue_event(task) task_updater = TaskUpdater(event_queue, task.id, task.context_id) event_emitter = A2AEvent(task_updater) - + user_input = context.get_user_input() if not user_input or not user_input.strip(): await event_emitter.emit_event("Error: Empty input provided", failed=True) @@ -124,35 +127,40 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): # Parse Messages messages = [HumanMessage(content=user_input)] input = {"messages": messages} - logger.info(f'Processing messages: {input}') + logger.info(f"Processing messages: {input}") try: output = None # Test MCP connection first - logger.info(f'Attempting to connect to MCP server(s) at: {config.MCP_URLS}') + logger.info(f"Attempting to connect to MCP server(s) at: {config.MCP_URLS}") mcpclient = get_mcpclient() # Try to get tools to verify connection try: tools = await mcpclient.get_tools() - logger.info(f'Successfully connected to MCP server(s). Available tools: {[tool.name for tool in tools]}') + logger.info( + f"Successfully connected to MCP server(s). Available tools: {[tool.name for tool in tools]}" + ) except Exception as tool_error: - logger.error(f'Failed to connect to MCP server(s): {tool_error}') - await event_emitter.emit_event(f"Error: Cannot connect to MCP server(s) at {config.MCP_URLS}. Please ensure the MCP server(s) are running. Error: {tool_error}", failed=True) + logger.error(f"Failed to connect to MCP server(s): {tool_error}") + await event_emitter.emit_event( + f"Error: Cannot connect to MCP server(s) at {config.MCP_URLS}. Please ensure the MCP server(s) are running. Error: {tool_error}", + failed=True, + ) return graph = await get_graph(mcpclient) async for event in graph.astream(input, stream_mode="updates"): await event_emitter.emit_event( "\n".join( - f"🚶‍♂️{key}: {str(value)[:config.MAX_EVENT_DISPLAY_LENGTH] + '...' if len(str(value)) > config.MAX_EVENT_DISPLAY_LENGTH else str(value)}" + f"🚶‍♂️{key}: {str(value)[: config.MAX_EVENT_DISPLAY_LENGTH] + '...' if len(str(value)) > config.MAX_EVENT_DISPLAY_LENGTH else str(value)}" for key, value in event.items() ) + "\n" ) output = event - logger.info(f'event: {event}') + logger.info(f"event: {event}") final_answer = output.get("assistant", {}).get("final_answer") if output else None if final_answer is None: @@ -161,7 +169,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): else: await event_emitter.emit_event(str(final_answer), final=True) except Exception as e: - logger.error(f'Graph execution error: {e}') + logger.error(f"Graph execution error: {e}") await event_emitter.emit_event(f"Error: Failed to process request. {str(e)}", failed=True) raise Exception(str(e)) @@ -171,6 +179,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """ raise Exception("cancel not supported") + def run(): """ Runs the A2A Agent application. @@ -190,11 +199,14 @@ def run(): app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/a2a/generic_agent/src/generic_agent/config.py b/a2a/generic_agent/src/generic_agent/config.py index 123ae89c..00222bf6 100644 --- a/a2a/generic_agent/src/generic_agent/config.py +++ b/a2a/generic_agent/src/generic_agent/config.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): LLM_MODEL: str = "llama3.2:3b-instruct-fp16" LLM_API_BASE: str = "http://localhost:11434/v1" @@ -7,4 +8,4 @@ class Configuration(BaseSettings): MCP_URLS: str = "http://localhost:8000/mcp" MCP_TRANSPORT: str = "streamable_http" MAX_EVENT_DISPLAY_LENGTH: int = 256 - AGENT_VERSION: str = "1.0.0" \ No newline at end of file + AGENT_VERSION: str = "1.0.0" diff --git a/a2a/generic_agent/src/generic_agent/graph.py b/a2a/generic_agent/src/generic_agent/graph.py index d8b2c5cc..17cce51f 100644 --- a/a2a/generic_agent/src/generic_agent/graph.py +++ b/a2a/generic_agent/src/generic_agent/graph.py @@ -1,31 +1,35 @@ -from langgraph.graph import StateGraph, MessagesState, START -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_core.messages import SystemMessage, AIMessage -from langgraph.prebuilt import tools_condition, ToolNode -from langchain_openai import ChatOpenAI from functools import lru_cache from typing import List +from langchain_core.messages import AIMessage, SystemMessage +from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_openai import ChatOpenAI +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + from generic_agent.config import Configuration config = Configuration() + # Extend MessagesState to include a final answer class ExtendedMessagesState(MessagesState): final_answer: str = "" + def _get_mcp_urls() -> List[str]: """Helper function to parse MCP URLs from environment variable.""" urls_str = config.MCP_URLS - return [url.strip() for url in urls_str.split(',') if url.strip()] + return [url.strip() for url in urls_str.split(",") if url.strip()] + @lru_cache(maxsize=1) def get_mcpclient() -> MultiServerMCPClient: urls = _get_mcp_urls() - + client_configs = {} transport = config.MCP_TRANSPORT - + for i, url in enumerate(urls, 1): client_configs[f"mcp{i}"] = { "url": url, @@ -33,10 +37,11 @@ def get_mcpclient() -> MultiServerMCPClient: } return MultiServerMCPClient(client_configs) + def get_mcp_server_names() -> List[str]: """ Extract MCP server names from URLs. - + Strips protocol (http/https), port, and path to get just the host names. Example: "http://weather-tool:8000/mcp" -> "weather-tool" @@ -48,16 +53,16 @@ def get_mcp_server_names() -> List[str]: mcp_names = [] for url in urls: # Remove protocol - name = url.replace('http://', '').replace('https://', '') + name = url.replace("http://", "").replace("https://", "") # Remove port and path (everything after first :) - name = name.split(':')[0] + name = name.split(":")[0] # Remove /mcp or any path - name = name.split('/')[0] + name = name.split("/")[0] if name: mcp_names.append(name) - + return mcp_names - + async def get_graph(client: MultiServerMCPClient) -> StateGraph: llm = ChatOpenAI( @@ -73,22 +78,20 @@ async def get_graph(client: MultiServerMCPClient) -> StateGraph: # System message sys_msg = SystemMessage( - content="You are the **Generic Assistant**, a multi-purpose, tool-based expert. Your primary directive is to fulfill user requests by effectively utilizing the available **MCP tools**. You will select the most appropriate tool(s) based on the user's need (e.g., weather, calculations, data retrieval) and strictly adhere to their output to generate your final answer. Be precise and concise." -) + content="You are the **Generic Assistant**, a multi-purpose, tool-based expert. Your primary directive is to fulfill user requests by effectively utilizing the available **MCP tools**. You will select the most appropriate tool(s) based on the user's need (e.g., weather, calculations, data retrieval) and strictly adhere to their output to generate your final answer. Be precise and concise." + ) # Node def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: result = llm_with_tools.invoke([sys_msg] + state["messages"]) - - updated_state = { - "messages": state["messages"] + [result] - } - + + updated_state = {"messages": state["messages"] + [result]} + # Set final_answer when LLM returns a text response (not a tool call) # This indicates the assistant has completed its reasoning and tool usage if isinstance(result, AIMessage) and not result.tool_calls: updated_state["final_answer"] = result.content - + return updated_state # Build graph @@ -104,4 +107,4 @@ def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: # Compile graph graph = builder.compile() - return graph \ No newline at end of file + return graph diff --git a/a2a/git_issue_agent/a2a_agent.py b/a2a/git_issue_agent/a2a_agent.py index d5ab2533..c0b3c085 100644 --- a/a2a/git_issue_agent/a2a_agent.py +++ b/a2a/git_issue_agent/a2a_agent.py @@ -5,21 +5,26 @@ import logging import sys import traceback -from typing import Callable import uvicorn from crewai_tools import MCPServerAdapter from crewai_tools.adapters.tool_collection import ToolCollection -from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater -from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart, SecurityScheme, HTTPAuthSecurityScheme +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + TaskState, + TextPart, + SecurityScheme, + HTTPAuthSecurityScheme, +) from a2a.utils import new_agent_text_message, new_task from starlette.routing import Route @@ -29,7 +34,8 @@ from git_issue_agent.main import GitIssueAgent logger = logging.getLogger(__name__) -logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format="%(levelname)s: %(message)s") + def get_agent_card(host: str, port: int): """Returns the Agent Card for the AG2 Agent.""" @@ -56,10 +62,7 @@ def get_agent_card(host: str, port: int): securitySchemes={ "Bearer": SecurityScheme( root=HTTPAuthSecurityScheme( - type="http", - scheme="bearer", - bearerFormat="JWT", - description="OAuth 2.0 JWT token" + type="http", scheme="bearer", bearerFormat="JWT", description="OAuth 2.0 JWT token" ) ) }, @@ -112,12 +115,8 @@ class GithubExecutor(AgentExecutor): """ A class to handle research execution for A2A Agent. """ - async def _run_agent(self, - messages: dict, - settings: Settings, - event_emitter: Event, - toolkit: ToolCollection): + async def _run_agent(self, messages: dict, settings: Settings, event_emitter: Event, toolkit: ToolCollection): git_issue_agent = GitIssueAgent( config=settings, eventer=event_emitter, @@ -142,10 +141,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): headers = {} if settings.GITHUB_TOKEN: headers["Authorization"] = f"Bearer {settings.GITHUB_TOKEN}" - elif context.call_context and (context.call_context.state or {}).get('headers', {}).get('authorization'): - headers["Authorization"] = context.call_context.state['headers']['authorization'] + elif context.call_context and (context.call_context.state or {}).get("headers", {}).get("authorization"): + headers["Authorization"] = context.call_context.state["headers"]["authorization"] else: - logging.warning("No GITHUB_TOKEN or inbound Authorization header; outbound requests will be unauthenticated") + logging.warning( + "No GITHUB_TOKEN or inbound Authorization header; outbound requests will be unauthenticated" + ) user_input = [context.get_user_input()] task = context.current_task @@ -178,8 +179,8 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): issue_tools = [ tool for tool in mcp_tools - if ("issue" in tool.name.lower() or "label" in tool.name.lower()) and - ("search" in tool.name.lower() or "list" in tool.name.lower()) + if ("issue" in tool.name.lower() or "label" in tool.name.lower()) + and ("search" in tool.name.lower() or "list" in tool.name.lower()) ] if not issue_tools: @@ -189,13 +190,13 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): ) await self._run_agent(messages, settings, event_emitter, issue_tools) else: - await self._run_agent(messages, settings, - event_emitter, - None) + await self._run_agent(messages, settings, event_emitter, None) except Exception as e: traceback.print_exc() - await event_emitter.emit_event(f"I'm sorry I was unable to fulfill your request. I encountered the following exception: {str(e)}", True) + await event_emitter.emit_event( + f"I'm sorry I was unable to fulfill your request. I encountered the following exception: {str(e)}", True + ) async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """ @@ -223,11 +224,14 @@ def run(): app = server.build() # this returns a Starlette app # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host="0.0.0.0", port=settings.SERVICE_PORT) diff --git a/a2a/git_issue_agent/git_issue_agent/agents.py b/a2a/git_issue_agent/git_issue_agent/agents.py index dcd8be27..8de15ec9 100644 --- a/a2a/git_issue_agent/git_issue_agent/agents.py +++ b/a2a/git_issue_agent/git_issue_agent/agents.py @@ -3,8 +3,9 @@ from git_issue_agent.data_types import IssueSearchInfo from git_issue_agent.llm import CrewLLM from git_issue_agent.prompts import TOOL_CALL_PROMPT, INFO_PARSER_PROMPT -class GitAgents(): - + + +class GitAgents: def __init__(self, config: Settings, issue_tools): self.llm = CrewLLM(config) @@ -16,25 +17,21 @@ def __init__(self, config: Settings, issue_tools): goal="To extract the information about github artifacts that a user is looking for", backstory=INFO_PARSER_PROMPT, verbose=True, - llm=self.llm.llm + llm=self.llm.llm, ) self.prereq_identifier_task = Task( - description=( - "User query: {request}" - ), + description=("User query: {request}"), agent=self.prereq_identifier, output_pydantic=IssueSearchInfo, - expected_output=( - "A pydantic object representing the extracted relevant information." - ), + expected_output=("A pydantic object representing the extracted relevant information."), ) self.prereq_id_crew = Crew( agents=[self.prereq_identifier], tasks=[self.prereq_identifier_task], process=Process.sequential, - verbose=True, + verbose=True, ) ################### @@ -51,9 +48,9 @@ def __init__(self, config: Settings, issue_tools): verbose=True, llm=self.llm.llm, inject_date=True, - max_iter=6 + max_iter=6, ) - + # --- A generic task template ------------------------------------------------- # The agent will use MCP tools to fulfill natural-language queries. self.issue_query_task = Task( @@ -75,4 +72,4 @@ def __init__(self, config: Settings, issue_tools): tasks=[self.issue_query_task], process=Process.sequential, verbose=True, - ) \ No newline at end of file + ) diff --git a/a2a/git_issue_agent/git_issue_agent/config.py b/a2a/git_issue_agent/git_issue_agent/config.py index 2496dead..39c6748e 100644 --- a/a2a/git_issue_agent/git_issue_agent/config.py +++ b/a2a/git_issue_agent/git_issue_agent/config.py @@ -1,13 +1,13 @@ import json import os -import sys from pydantic_settings import BaseSettings from pydantic import model_validator from pydantic import Field from typing import Literal, Optional + class Settings(BaseSettings): - LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = Field( + LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( os.getenv("LOG_LEVEL", "INFO"), description="Application log level", ) @@ -26,9 +26,14 @@ class Settings(BaseSettings): description="The temperature for the model", ge=0, ) - MCP_URL: str = Field(os.getenv("MCP_URL", "https://api.githubcopilot.com/mcp/"), description="Endpoint for an option MCP server") + MCP_URL: str = Field( + os.getenv("MCP_URL", "https://api.githubcopilot.com/mcp/"), description="Endpoint for an option MCP server" + ) SERVICE_PORT: int = Field(os.getenv("SERVICE_PORT", 8000), description="Port on which the service will run.") - GITHUB_TOKEN: Optional[str] = Field(os.getenv("GITHUB_TOKEN", None), description="If not using agent with authorization, the default Github token to use") + GITHUB_TOKEN: Optional[str] = Field( + os.getenv("GITHUB_TOKEN", None), + description="If not using agent with authorization, the default Github token to use", + ) class Config: env_file = ".env" @@ -44,4 +49,5 @@ def validate_extra_headers(self) -> "Settings": return self + settings = Settings() # type: ignore[call-arg] diff --git a/a2a/git_issue_agent/git_issue_agent/data_types.py b/a2a/git_issue_agent/git_issue_agent/data_types.py index a20c67a9..ca55bcd7 100644 --- a/a2a/git_issue_agent/git_issue_agent/data_types.py +++ b/a2a/git_issue_agent/git_issue_agent/data_types.py @@ -1,11 +1,13 @@ from pydantic import BaseModel, Field -from typing import Literal, Optional ############ # Pydantic types for LLM response formats ############ + class IssueSearchInfo(BaseModel): owner: str = Field(None, description="The issue owner or organization.") repo: str = Field(None, description="The specified repository. Leave blank if none specified.") - issue_numbers: list[int] = Field(None, description="Specific issue number(s) mentioned by the user. If none mentioned leave blank.") \ No newline at end of file + issue_numbers: list[int] = Field( + None, description="Specific issue number(s) mentioned by the user. If none mentioned leave blank." + ) diff --git a/a2a/git_issue_agent/git_issue_agent/llm.py b/a2a/git_issue_agent/git_issue_agent/llm.py index 73fb0fcb..e2169a0e 100644 --- a/a2a/git_issue_agent/git_issue_agent/llm.py +++ b/a2a/git_issue_agent/git_issue_agent/llm.py @@ -1,12 +1,16 @@ from crewai import LLM from git_issue_agent.config import Settings -class CrewLLM(): - def __init__(self, config: Settings): +class CrewLLM: + def __init__(self, config: Settings): self.llm = LLM( model=config.TASK_MODEL_ID, base_url=config.LLM_API_BASE, api_key=config.LLM_API_KEY, - **({'extra_headers': config.EXTRA_HEADERS} if config.EXTRA_HEADERS is not None and None not in config.EXTRA_HEADERS else {}) - ) \ No newline at end of file + **( + {"extra_headers": config.EXTRA_HEADERS} + if config.EXTRA_HEADERS is not None and None not in config.EXTRA_HEADERS + else {} + ), + ) diff --git a/a2a/git_issue_agent/git_issue_agent/main.py b/a2a/git_issue_agent/git_issue_agent/main.py index 0417fbd0..311927d2 100644 --- a/a2a/git_issue_agent/git_issue_agent/main.py +++ b/a2a/git_issue_agent/git_issue_agent/main.py @@ -1,7 +1,4 @@ -from dataclasses import dataclass -from crewai_tools import MCPServerAdapter from crewai_tools.adapters.tool_collection import ToolCollection -from typing import Callable import logging import sys @@ -11,15 +8,17 @@ from git_issue_agent.agents import GitAgents logger = logging.getLogger(__name__) -logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format="%(levelname)s: %(message)s") class GitIssueAgent: - def __init__(self, config: Settings, + def __init__( + self, + config: Settings, eventer: Event = None, mcp_toolkit: ToolCollection = None, - logger=None,): - + logger=None, + ): self.agents = GitAgents(settings, mcp_toolkit) self.eventer = eventer self.logger = logger or logging.getLogger(__name__) @@ -30,7 +29,7 @@ async def _send_event(self, message: str, final: bool = False): await self.eventer.emit_event(message, final) else: logger.warning("No event handler registered") - + def extract_user_input(self, body): content = body[-1]["content"] latest_content = "" @@ -49,11 +48,9 @@ def extract_user_input(self, body): async def execute(self, user_input): query = self.extract_user_input(user_input) await self._send_event("🧐 Evaluating requirements...") - await self.agents.prereq_id_crew.kickoff_async( - inputs={"request": query, "repo": "", "owner": "", "issues": []} - ) + await self.agents.prereq_id_crew.kickoff_async(inputs={"request": query, "repo": "", "owner": "", "issues": []}) repo_id_task_output = self.agents.prereq_identifier_task.output.pydantic - + if repo_id_task_output.issue_numbers: if not repo_id_task_output.owner or not repo_id_task_output.repo: return "When supplying issue numbers, you must provide both a repository name and owner." @@ -62,6 +59,12 @@ async def execute(self, user_input): return "When supplying a repository name, you must also provide an owner of the repo." await self._send_event("🔎 Searching for issues...") - await self.agents.crew.kickoff_async(inputs={"request": query, "owner": repo_id_task_output.owner, "repo": repo_id_task_output.repo, "issues": repo_id_task_output.issue_numbers}) + await self.agents.crew.kickoff_async( + inputs={ + "request": query, + "owner": repo_id_task_output.owner, + "repo": repo_id_task_output.repo, + "issues": repo_id_task_output.issue_numbers, + } + ) return self.agents.issue_query_task.output.raw - diff --git a/a2a/image_service/src/image_service/__init__.py b/a2a/image_service/src/image_service/__init__.py index 4b53b595..5186b4fb 100644 --- a/a2a/image_service/src/image_service/__init__.py +++ b/a2a/image_service/src/image_service/__init__.py @@ -1,16 +1,20 @@ -from opentelemetry.sdk.resources import Resource from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + def setup_tracer(): - resource = Resource.create(attributes={ - "service.name": "image-service", - }) + resource = Resource.create( + attributes={ + "service.name": "image-service", + } + ) provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(OTLPSpanExporter()) provider.add_span_processor(processor) trace.set_tracer_provider(provider) + setup_tracer() diff --git a/a2a/image_service/src/image_service/agent.py b/a2a/image_service/src/image_service/agent.py index cb823db1..402ace01 100644 --- a/a2a/image_service/src/image_service/agent.py +++ b/a2a/image_service/src/image_service/agent.py @@ -2,19 +2,26 @@ import logging import os from textwrap import dedent + import uvicorn +from langchain_core.messages import HumanMessage +from openinference.instrumentation.langchain import LangChainInstrumentor +from starlette.routing import Route from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater -from starlette.routing import Route -from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart, DataPart +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + DataPart, + TaskState, + TextPart, +) from a2a.utils import new_agent_text_message, new_task -from openinference.instrumentation.langchain import LangChainInstrumentor -from langchain_core.messages import HumanMessage - from image_service.graph import get_graph, get_mcpclient logging.basicConfig(level=logging.DEBUG) @@ -22,6 +29,7 @@ LangChainInstrumentor().instrument() + def get_agent_card(host: str, port: int): """Returns the Agent Card for the Image Agent.""" capabilities = AgentCapabilities(streaming=True) @@ -82,22 +90,31 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): """Fetch an image (base64) from the MCP image_tool and return it to the UI.""" task = context.current_task if not task: - task = new_task(context.message) + task = new_task(context.message) await event_queue.enqueue_event(task) task_updater = TaskUpdater(event_queue, task.id, task.context_id) event_emitter = ImageTaskEventEmitter(task_updater) try: # Test MCP connection first - logger.info('Attempting to connect to MCP server at: %s', os.getenv("MCP_URL", "http://localhost:8000/mcp")) + logger.info( + "Attempting to connect to MCP server at: %s", + os.getenv("MCP_URL", "http://localhost:8000/mcp"), + ) mcpclient = get_mcpclient() # Try to get tools to verify connection try: tools = await mcpclient.get_tools() - logger.info('Successfully connected to MCP server. Available tools: %s', [tool.name for tool in tools]) + logger.info( + "Successfully connected to MCP server. Available tools: %s", + [tool.name for tool in tools], + ) except Exception as tool_error: - logger.error('Failed to connect to MCP server: %s', tool_error) - await event_emitter.emit_event(f"Error: Cannot connect to MCP image service at {os.getenv('MCP_URL', 'http://localhost:8000/mcp')}. Please ensure the image MCP server is running. Error: {tool_error}", failed=True) + logger.error("Failed to connect to MCP server: %s", tool_error) + await event_emitter.emit_event( + f"Error: Cannot connect to MCP image service at {os.getenv('MCP_URL', 'http://localhost:8000/mcp')}. Please ensure the image MCP server is running. Error: {tool_error}", + failed=True, + ) return graph = await get_graph(mcpclient) @@ -118,8 +135,8 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): err_msg = "No events were produced by the graph stream; cannot process result." logger.error(err_msg) await event_emitter.emit_event(err_msg, failed=True) - return - + return + result = output.get("assistant", {}).get("final_answer") if not result: @@ -160,12 +177,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): await task_updater.add_artifact(parts, name="image.png") await task_updater.complete() return - + # Fallback: treat as text if not result or (isinstance(result, str) and result.strip() == ""): await event_emitter.emit_event( "I am here to help with image requests. Please ask for an image with specific dimensions.", - final=True + final=True, ) else: await event_emitter.emit_event(str(result), final=True) @@ -176,14 +193,18 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): return except Exception as e: - logger.exception('Graph execution error') - await event_emitter.emit_event(f"Error: Failed to process image request. {type(e).__name__}: {str(e)}", failed=True) + logger.exception("Graph execution error") + await event_emitter.emit_event( + f"Error: Failed to process image request. {type(e).__name__}: {str(e)}", + failed=True, + ) return async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """Not implemented""" raise NotImplementedError("cancel not supported") + def run(): agent_card = get_agent_card(host="0.0.0.0", port=8000) @@ -200,11 +221,14 @@ def run(): app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/a2a/image_service/src/image_service/configuration.py b/a2a/image_service/src/image_service/configuration.py index 0c7dcb66..dd1c9f4c 100644 --- a/a2a/image_service/src/image_service/configuration.py +++ b/a2a/image_service/src/image_service/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "llama3.1" llm_api_base: str = "http://localhost:11434/v1" diff --git a/a2a/image_service/src/image_service/graph.py b/a2a/image_service/src/image_service/graph.py index 001df59d..53cc720f 100644 --- a/a2a/image_service/src/image_service/graph.py +++ b/a2a/image_service/src/image_service/graph.py @@ -1,27 +1,34 @@ -from langgraph.graph import StateGraph, MessagesState, START, END -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_core.messages import SystemMessage, ToolMessage, AIMessage +import json +import os from textwrap import dedent -from langgraph.prebuilt import tools_condition, ToolNode +from typing import Optional + +from langchain_core.messages import AIMessage, SystemMessage, ToolMessage +from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_openai import ChatOpenAI -import os -import json +from langgraph.graph import END, START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + from image_service.configuration import Configuration -from typing import Optional config = Configuration() + # Extend MessagesState to include a final answer class ExtendedMessagesState(MessagesState): final_answer: Optional[dict] = None + def get_mcpclient(): - return MultiServerMCPClient({ - "image": { - "url": os.getenv("MCP_URL", "http://localhost:8000/mcp"), - "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + return MultiServerMCPClient( + { + "image": { + "url": os.getenv("MCP_URL", "http://localhost:8000/mcp"), + "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + } } - }) + ) + async def get_graph(client) -> StateGraph: llm = ChatOpenAI( @@ -36,26 +43,28 @@ async def get_graph(client) -> StateGraph: llm_with_tools = llm.bind_tools(tools) # System message - sys_msg = SystemMessage(content=dedent( - """\ + sys_msg = SystemMessage( + content=dedent( + """\ You are a helpful assistant. Only call the get_image tool when the user EXPLICITLY asks for an image with specific dimensions (e.g., 'show me an image', 'generate an image 400x400', 'image 200 300'). For any conversation that does NOT explicitly request an image, respond directly with text. DO NOT call any tools for these cases. When you do call get_image, you MUST provide valid positive integers for both height and width parameters. - """) + """ + ) ) # Node def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: result = llm_with_tools.invoke([sys_msg] + state["messages"]) state["messages"].append(result) - + if isinstance(result, AIMessage) and not result.tool_calls: state["final_answer"] = {"raw": result.content} new_messages = state["messages"] + [result] # Find the most recent ToolMessage and set its content as final_answer. # NOTE: Only the most recent ToolMessage is processed intentionally. - # If multiple tools are called in sequence, earlier tool results are + # If multiple tools are called in sequence, earlier tool results are # intermediate steps, while the final ToolMessage represents the complete # answer to return to the user. The graph ends once final_answer is set. final_answer = state.get("final_answer") @@ -75,7 +84,7 @@ def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: except Exception as e: final_answer = { "error": "Failed to process tool result", - "details": str(e) + "details": str(e), } break @@ -90,12 +99,12 @@ def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: "assistant", tools_condition, ) - + # After tools run, check if we have final_answer, if so END, otherwise go back to assistant def should_continue(state: ExtendedMessagesState): # End the graph once a final_answer (tool result) is captured return END if state.get("final_answer") is not None else "assistant" - + builder.add_conditional_edges( "tools", should_continue, @@ -103,4 +112,4 @@ def should_continue(state: ExtendedMessagesState): # Compile and return graph graph = builder.compile() - return graph \ No newline at end of file + return graph diff --git a/a2a/recipe_agent/src/recipe_agent/__init__.py b/a2a/recipe_agent/src/recipe_agent/__init__.py index 6fa663ce..66765224 100644 --- a/a2a/recipe_agent/src/recipe_agent/__init__.py +++ b/a2a/recipe_agent/src/recipe_agent/__init__.py @@ -1,16 +1,20 @@ -from opentelemetry.sdk.resources import Resource from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + def setup_tracer(): - resource = Resource.create(attributes={ - "service.name": "recipe-agent", - }) + resource = Resource.create( + attributes={ + "service.name": "recipe-agent", + } + ) provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(OTLPSpanExporter()) provider.add_span_processor(processor) trace.set_tracer_provider(provider) + setup_tracer() diff --git a/a2a/recipe_agent/src/recipe_agent/agent.py b/a2a/recipe_agent/src/recipe_agent/agent.py index 208bd00f..72760665 100644 --- a/a2a/recipe_agent/src/recipe_agent/agent.py +++ b/a2a/recipe_agent/src/recipe_agent/agent.py @@ -1,10 +1,10 @@ import logging +from textwrap import dedent import uvicorn from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route -from textwrap import dedent from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication @@ -13,7 +13,6 @@ from a2a.server.tasks import InMemoryTaskStore, TaskUpdater from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task - from recipe_agent.recipe_llm import chat logging.basicConfig(level=logging.DEBUG) diff --git a/a2a/recipe_agent/src/recipe_agent/configuration.py b/a2a/recipe_agent/src/recipe_agent/configuration.py index de054c6e..026d3ac0 100644 --- a/a2a/recipe_agent/src/recipe_agent/configuration.py +++ b/a2a/recipe_agent/src/recipe_agent/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "qwen2.5:3b" llm_api_base: str = "http://host.docker.internal:11434/v1" diff --git a/a2a/reservation_service/src/reservation_service/agent.py b/a2a/reservation_service/src/reservation_service/agent.py index 6314e7d9..b118923c 100644 --- a/a2a/reservation_service/src/reservation_service/agent.py +++ b/a2a/reservation_service/src/reservation_service/agent.py @@ -1,19 +1,19 @@ import logging import os -import uvicorn from textwrap import dedent +import uvicorn +from langchain_core.messages import HumanMessage +from openinference.instrumentation.langchain import LangChainInstrumentor +from starlette.routing import Route + from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater -from starlette.routing import Route from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task -from openinference.instrumentation.langchain import LangChainInstrumentor -from langchain_core.messages import HumanMessage - from reservation_service.graph import get_graph, get_mcpclient logging.basicConfig(level=logging.DEBUG) @@ -68,6 +68,7 @@ def get_agent_card(host: str, port: int): skills=[skill], ) + class A2AEvent: """ A class to handle events for A2A Agent. @@ -99,10 +100,12 @@ async def emit_event(self, message: str, final: bool = False, failed: bool = Fal ), ) + class ReservationExecutor(AgentExecutor): """ A class to handle reservation assistant execution for A2A Agent. """ + async def execute(self, context: RequestContext, event_queue: EventQueue): """ The agent allows restaurant reservations through a natural language conversational interface @@ -119,26 +122,26 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): # Parse Messages messages = [HumanMessage(content=context.get_user_input())] input = {"messages": messages} - logger.info(f'Processing messages: {input}') + logger.info(f"Processing messages: {input}") try: output = None # Test MCP connection first mcp_url = os.getenv("MCP_URL", "http://reservation-tool:8000/mcp") - logger.info(f'Attempting to connect to MCP server at: {mcp_url}') + logger.info(f"Attempting to connect to MCP server at: {mcp_url}") mcpclient = get_mcpclient() # Try to get tools to verify connection try: tools = await mcpclient.get_tools() - logger.info(f'Successfully connected to MCP server. Available tools: {[tool.name for tool in tools]}') + logger.info(f"Successfully connected to MCP server. Available tools: {[tool.name for tool in tools]}") except Exception as tool_error: - logger.error(f'Failed to connect to MCP server: {tool_error}') + logger.error(f"Failed to connect to MCP server: {tool_error}") await event_emitter.emit_event( f"Error: Cannot connect to reservation MCP service at {mcp_url}. " f"Please ensure the reservation MCP server is running. Error: {tool_error}", - failed=True + failed=True, ) return @@ -152,14 +155,14 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): + "\n" ) output = event - logger.info(f'event: {event}') + logger.info(f"event: {event}") if output is not None: final_answer = output.get("assistant", {}).get("final_answer") await event_emitter.emit_event(str(final_answer), final=True) else: await event_emitter.emit_event("No events produced by the graph.", final=True) except Exception as e: - logger.error(f'Graph execution error: {e}') + logger.error(f"Graph execution error: {e}") await event_emitter.emit_event(f"Error: Failed to process reservation request. {str(e)}", failed=True) raise Exception(str(e)) @@ -169,6 +172,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """ raise Exception("cancel not supported") + def run(): """ Runs the A2A Agent application. @@ -188,11 +192,14 @@ def run(): app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/a2a/reservation_service/src/reservation_service/configuration.py b/a2a/reservation_service/src/reservation_service/configuration.py index 1579c232..db9dcc17 100644 --- a/a2a/reservation_service/src/reservation_service/configuration.py +++ b/a2a/reservation_service/src/reservation_service/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "llama3.2:3b-instruct-fp16" llm_api_base: str = "http://host.docker.internal:11434/v1" diff --git a/a2a/reservation_service/src/reservation_service/graph.py b/a2a/reservation_service/src/reservation_service/graph.py index 1374820f..4e535b5a 100644 --- a/a2a/reservation_service/src/reservation_service/graph.py +++ b/a2a/reservation_service/src/reservation_service/graph.py @@ -1,24 +1,31 @@ -from langgraph.graph import StateGraph, MessagesState, START +import os + +from langchain_core.messages import AIMessage, SystemMessage from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_core.messages import SystemMessage, AIMessage -from langgraph.prebuilt import tools_condition, ToolNode from langchain_openai import ChatOpenAI -import os +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + from reservation_service.configuration import Configuration config = Configuration() + # Extend MessagesState to include a final answer class ExtendedMessagesState(MessagesState): final_answer: str = "" + def get_mcpclient(): - return MultiServerMCPClient({ - "reservations": { - "url": os.getenv("MCP_URL", "http://reservation-tool:8000/mcp"), - "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + return MultiServerMCPClient( + { + "reservations": { + "url": os.getenv("MCP_URL", "http://reservation-tool:8000/mcp"), + "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + } } - }) + ) + async def get_graph(client) -> StateGraph: llm = ChatOpenAI( @@ -33,7 +40,8 @@ async def get_graph(client) -> StateGraph: llm_with_tools = llm.bind_tools(tools) # System message - sys_msg = SystemMessage(content="""You are a helpful restaurant reservation assistant. You have access to tools for: + sys_msg = SystemMessage( + content="""You are a helpful restaurant reservation assistant. You have access to tools for: - Searching restaurants by city, cuisine, price tier - Checking availability at restaurants - Making reservations @@ -47,7 +55,8 @@ async def get_graph(client) -> StateGraph: 4. Provide confirmation codes when reservations are successful 5. Be conversational and helpful -Use the provided tools to complete your tasks.""") +Use the provided tools to complete your tasks.""" + ) # Node def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: diff --git a/a2a/reservation_service/test_agent.py b/a2a/reservation_service/test_agent.py index 764b5993..d25974c0 100644 --- a/a2a/reservation_service/test_agent.py +++ b/a2a/reservation_service/test_agent.py @@ -2,40 +2,32 @@ # -*- coding: utf-8 -*- """Simple test client for the reservation agent.""" +import io +import sys import time + import requests -import sys -import io # Fix Windows console encoding if sys.platform == "win32": - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") AGENT_URL = "http://localhost:8001" + def chat_with_agent(prompt: str): """Send a message to the agent and get the response.""" - print(f"\n{'='*80}") + print(f"\n{'=' * 80}") print(f"YOU: {prompt}") - print(f"{'='*80}\n") + print(f"{'=' * 80}\n") # Create a task task_request = { "jsonrpc": "2.0", "id": 1, "method": "agent.task.create", - "params": { - "message": { - "role": "user", - "parts": [ - { - "type": "text", - "text": prompt - } - ] - } - } + "params": {"message": {"role": "user", "parts": [{"type": "text", "text": prompt}]}}, } print("🤖 Agent is thinking...\n") @@ -79,9 +71,7 @@ def chat_with_agent(prompt: str): "jsonrpc": "2.0", "id": attempt + 2, "method": "agent.task.get", - "params": { - "task_id": task_id - } + "params": {"task_id": task_id}, } try: @@ -116,15 +106,15 @@ def chat_with_agent(prompt: str): print(f"\n✅ Task {status}") break - print(f"{'='*80}\n") + print(f"{'=' * 80}\n") def main(): """Run the test scenarios.""" - print("\n" + "="*80) - print(" "*25 + "RESERVATION AGENT DEMO") - print("="*80 + "\n") + print("\n" + "=" * 80) + print(" " * 25 + "RESERVATION AGENT DEMO") + print("=" * 80 + "\n") test_prompts = [ "Find Italian restaurants in Boston", @@ -138,9 +128,9 @@ def main(): print("3. 📝 Make reservations with guest details\n") for i, prompt in enumerate(test_prompts, 1): - print(f"\n{'#'*80}") + print(f"\n{'#' * 80}") print(f" DEMO {i}/{len(test_prompts)}") - print(f"{'#'*80}") + print(f"{'#' * 80}") chat_with_agent(prompt) @@ -148,9 +138,9 @@ def main(): print("\n⏸️ Press Enter to continue to next demo...") input() - print("\n" + "="*80) - print(" "*20 + "✅ DEMO COMPLETE!") - print("="*80 + "\n") + print("\n" + "=" * 80) + print(" " * 20 + "✅ DEMO COMPLETE!") + print("=" * 80 + "\n") if __name__ == "__main__": @@ -161,4 +151,5 @@ def main(): except Exception as e: print(f"\n\n❌ Error: {e}") import traceback + traceback.print_exc() diff --git a/a2a/simple_generalist/src/simple_generalist/a2a_server/server.py b/a2a/simple_generalist/src/simple_generalist/a2a_server/server.py index 646fe889..c8a08168 100644 --- a/a2a/simple_generalist/src/simple_generalist/a2a_server/server.py +++ b/a2a/simple_generalist/src/simple_generalist/a2a_server/server.py @@ -31,22 +31,22 @@ def get_agent_card(settings: Settings) -> AgentCard: """ Create the Agent Card for the Simple Generalist Agent. - + Args: settings: Application settings - + Returns: AgentCard describing the agent's capabilities """ capabilities = AgentCapabilities(streaming=True) - + # Create skill description skill_description = ( "A general-purpose agent that can use MCP tools to accomplish various tasks. " "The agent uses a function-calling loop to iteratively solve problems by calling tools " "and synthesizing results." ) - + skill = AgentSkill( id="simple_generalist_agent", name="Simple Generalist Agent", @@ -58,7 +58,7 @@ def get_agent_card(settings: Settings) -> AgentCard: "Perform multi-step operations with tool assistance", ], ) - + agent_url = settings.A2A_PUBLIC_URL if not agent_url: if settings.A2A_HOST == "0.0.0.0": @@ -80,16 +80,16 @@ def get_agent_card(settings: Settings) -> AgentCard: class SimpleGeneralistExecutor(AgentExecutor): """Agent executor for the Simple Generalist Agent.""" - + def __init__(self, settings: Settings): """ Initialize the executor. - + Args: settings: Application settings """ self.settings = settings - + async def _run_agent( self, user_input: str, @@ -117,7 +117,7 @@ async def _run_agent( async def execute(self, context: RequestContext, event_queue: EventQueue): """ Execute a task request. - + Args: context: Request context event_queue: Event queue for progress updates @@ -127,10 +127,10 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if not task: task = new_task(context.message) # type: ignore await event_queue.enqueue_event(task) - + # Create task updater for progress events task_updater = TaskUpdater(event_queue, task.id, task.context_id) - + # Create event callback async def event_callback(message: str, final: bool = False): """Send progress events to the client.""" @@ -158,11 +158,11 @@ async def error_callback(message: str): parts = [Part(root=TextPart(text=message))] await task_updater.add_artifact(parts) await task_updater.failed() - + # Extract user input user_input = context.get_user_input() logger.info(f"Processing request: {user_input}") - + # Hook up MCP tools (per-request connection like in a2a_agent.py) toolkit = None try: @@ -170,19 +170,20 @@ async def error_callback(message: str): if mcp_url: logger.info(f"Connecting to MCP server at {mcp_url}") - async with streamablehttp_client( - url=mcp_url, - timeout=30, - sse_read_timeout=300, - ) as ( - read_stream, - write_stream, - _, - ), ClientSession(read_stream, write_stream) as session: + async with ( + streamablehttp_client( + url=mcp_url, + timeout=30, + sse_read_timeout=300, + ) as ( + read_stream, + write_stream, + _, + ), + ClientSession(read_stream, write_stream) as session, + ): await session.initialize() - toolkit = await create_toolkit( - session=session, use_mcp_resources=False - ) + toolkit = await create_toolkit(session=session, use_mcp_resources=False) await self._run_agent( user_input, self.settings, @@ -198,17 +199,17 @@ async def error_callback(message: str): error_callback, toolkit, ) - + except Exception as exc: traceback.print_exc() logger.error(f"Error executing task: {exc}", exc_info=True) error_message = f"I encountered an error while processing your request: {str(exc)}" await error_callback(error_message) - + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """ Cancel a task (not implemented). - + Args: context: Request context event_queue: Event queue @@ -219,32 +220,33 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None def create_app(settings: Settings) -> Any: """ Create the A2A Starlette application. - + Args: settings: Application settings - + Returns: Starlette application """ # Create agent card agent_card = get_agent_card(settings) - + # Create request handler request_handler = DefaultRequestHandler( agent_executor=SimpleGeneralistExecutor(settings), task_store=InMemoryTaskStore(), ) - + # Create A2A server server = A2AStarletteApplication( agent_card=agent_card, http_handler=request_handler, ) - + # Build and return the app app = server.build() logger.info("A2A server application created") - + return app + # Made with Bob diff --git a/a2a/simple_generalist/src/simple_generalist/agent/generalist_agent.py b/a2a/simple_generalist/src/simple_generalist/agent/generalist_agent.py index 3200c2ca..801cd4ad 100644 --- a/a2a/simple_generalist/src/simple_generalist/agent/generalist_agent.py +++ b/a2a/simple_generalist/src/simple_generalist/agent/generalist_agent.py @@ -70,9 +70,7 @@ def _init_tracing() -> TracerProvider: _tracer_provider.add_span_processor(AgentIdSpanProcessor(_AGENT_IDS)) if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"): - _tracer_provider.add_span_processor( - BatchSpanProcessor(OTLPSpanExporter()) - ) + _tracer_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) logger.info("AG2 OpenTelemetry tracing enabled (OTLP endpoint: %s)", os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]) elif os.environ.get("OTEL_CONSOLE_TRACING", "").lower() in ("true", "1", "yes"): _tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter())) @@ -85,18 +83,16 @@ def _init_tracing() -> TracerProvider: return _tracer_provider - - class GeneralistAgent: """ Generalist agent that uses AG2 for LLM interaction and MCP tools for actions. - + - Maintains conversation state - Calls LLM for next action - Executes tools via MCP - Iterates until completion or limits """ - + def __init__( self, settings: Settings, @@ -105,7 +101,7 @@ def __init__( ): """ Initialize the generalist agent. - + Args: settings: Application settings mcp_toolkit: Optional AG2 MCP toolkit with connected servers @@ -118,7 +114,7 @@ def __init__( # Initialize AG2 agent self._init_ag2_agent() - + def _init_ag2_agent(self): """Initialize the AG2 conversable agent without registering tools.""" # Build LLM config @@ -132,16 +128,16 @@ def _init_ag2_agent(self): # Add API key if provided if self.settings.LLM_API_KEY: llm_config["api_key"] = self.settings.LLM_API_KEY - + # Add base URL if provided (for custom endpoints) if self.settings.LLM_BASE_URL: llm_config["base_url"] = self.settings.LLM_BASE_URL if self.settings.EXTRA_HEADERS: llm_config["default_headers"] = self.settings.EXTRA_HEADERS - + system_message = GENERAL_AGENT_PROMPT.format(max_steps=self.settings.MAX_ITERATIONS) - + # Create the agent self.agent = ConversableAgent( name="generalist_agent", @@ -169,8 +165,7 @@ def _init_ag2_agent(self): # Instrument agents for tracing instrument_agent(self.agent, tracer_provider=self._tracer_provider) instrument_agent(self.user_proxy, tracer_provider=self._tracer_provider) - - + async def _emit_event(self, message: str, final: bool = False): """Emit a progress event if callback is set.""" if self.event_callback: @@ -178,16 +173,16 @@ async def _emit_event(self, message: str, final: bool = False): await self.event_callback(message, final) except Exception as exc: logger.error(f"Error in event callback: {exc}") - + async def run_task(self, instruction: str) -> dict[str, Any]: """ Run a task with the given instruction. - + Uses AG2's built-in conversation flow with MCP tools. - + Args: instruction: User instruction/query - + Returns: Dictionary with: - answer: Final answer text @@ -196,29 +191,28 @@ async def run_task(self, instruction: str) -> dict[str, Any]: """ logger.info(f"Starting task: {instruction}") await self._emit_event("🤖 Starting task execution...") - + try: # Initiate chat - AG2 handles the tool calling loop - await self._emit_event(f"🔄 Processing with AG2 agent...") - + await self._emit_event("🔄 Processing with AG2 agent...") + # Run the synchronous initiate_chat in a thread pool to avoid blocking await self.user_proxy.a_initiate_chat( - self.agent, - message=instruction, - max_turns=self.settings.MAX_ITERATIONS) - + self.agent, message=instruction, max_turns=self.settings.MAX_ITERATIONS + ) + # Get the final response from chat history chat_history = self.user_proxy.chat_messages.get(self.agent, []) - + # Extract final answer from the last assistant message final_answer = "No response generated" for msg in reversed(chat_history): if msg.get("role") == "assistant" and msg.get("content"): final_answer = msg["content"] break - + logger.info("Task completed successfully") - + result = { "answer": final_answer, "iterations": len(chat_history), @@ -235,4 +229,5 @@ async def run_task(self, instruction: str) -> dict[str, Any]: "error": True, } + # Made with Bob diff --git a/a2a/simple_generalist/src/simple_generalist/agent/prompts.py b/a2a/simple_generalist/src/simple_generalist/agent/prompts.py index 3dec78ec..39356197 100644 --- a/a2a/simple_generalist/src/simple_generalist/agent/prompts.py +++ b/a2a/simple_generalist/src/simple_generalist/agent/prompts.py @@ -45,4 +45,4 @@ # Real Task Instruction -""" \ No newline at end of file +""" diff --git a/a2a/simple_generalist/src/simple_generalist/config/settings.py b/a2a/simple_generalist/src/simple_generalist/config/settings.py index adf19c33..0b9d2c94 100644 --- a/a2a/simple_generalist/src/simple_generalist/config/settings.py +++ b/a2a/simple_generalist/src/simple_generalist/config/settings.py @@ -13,13 +13,13 @@ class Settings(BaseSettings): """Application settings loaded from environment variables.""" - + # Logging - LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = Field( + LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( default=os.getenv("LOG_LEVEL", "INFO"), # type: ignore[arg-type] description="Application log level", ) - + # A2A Server Configuration A2A_HOST: str = Field( default=os.getenv("A2A_HOST", "0.0.0.0"), @@ -29,14 +29,14 @@ class Settings(BaseSettings): default=int(os.getenv("A2A_PORT", "8000")), description="Port for A2A server", ) - + # MCP Server Configuration MCP_SERVER_URL: str = Field( default=os.getenv("MCP_SERVER_URL", ""), description="MCP server URL", validation_alias=AliasChoices("MCP_SERVER_URL", "MCP_SERVERS"), ) - + # LLM Configuration LLM_MODEL: str = Field( default=os.getenv("LLM_MODEL", "gpt-4"), @@ -65,7 +65,7 @@ class Settings(BaseSettings): ) EXTRA_HEADERS: dict[str, str] = Field( default_factory=dict, - description="Extra headers for the OpenAI API (JSON string, e.g. '{\"key\": \"value\"}')", + description='Extra headers for the OpenAI API (JSON string, e.g. \'{"key": "value"}\')', ) @field_validator("EXTRA_HEADERS", mode="before") @@ -84,7 +84,7 @@ def _parse_extra_headers(cls, v: Any) -> dict[str, str]: default=os.getenv("OTEL_CONSOLE_TRACING", "false").lower() in ("true", "1", "yes"), description="Print OpenTelemetry traces to console when no OTLP endpoint is configured", ) - + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", @@ -96,4 +96,5 @@ def load_settings() -> Settings: """Load and return application settings.""" return Settings() # type: ignore[call-arg] + # Made with Bob diff --git a/a2a/simple_generalist/src/simple_generalist/main.py b/a2a/simple_generalist/src/simple_generalist/main.py index 2324fb3d..b1abee8a 100644 --- a/a2a/simple_generalist/src/simple_generalist/main.py +++ b/a2a/simple_generalist/src/simple_generalist/main.py @@ -13,13 +13,13 @@ def setup_logging(settings: Settings): """ Setup logging configuration. - + Args: settings: Application settings """ logging.basicConfig( level=settings.LOG_LEVEL, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", stream=sys.stdout, ) @@ -28,26 +28,26 @@ def run(): """Run the Simple Generalist server.""" # Load settings settings = Settings() # type: ignore[call-arg] - + # Setup logging setup_logging(settings) - + logger.info("Starting Simple Generalist Agent") logger.info(f"LLM Model: {settings.LLM_MODEL}") logger.info(f"Max Iterations: {settings.MAX_ITERATIONS}") - + if settings.MCP_SERVER_URL.strip(): logger.info(f"MCP Server URL: {settings.MCP_SERVER_URL.strip()}") else: logger.warning("No MCP server configured - agent will run without tools") - + # Create A2A app (MCP connection will be established per-request) try: app = create_app(settings) except Exception as exc: logger.error(f"Failed to create A2A app: {exc}", exc_info=True) sys.exit(1) - + # Run server logger.info(f"Starting A2A server on {settings.A2A_HOST}:{settings.A2A_PORT}") uvicorn.run( diff --git a/a2a/slack_researcher/a2a_agent.py b/a2a/slack_researcher/a2a_agent.py index 92007b56..eace30c9 100644 --- a/a2a/slack_researcher/a2a_agent.py +++ b/a2a/slack_researcher/a2a_agent.py @@ -17,7 +17,15 @@ from a2a.server.events.event_queue import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater -from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart, SecurityScheme, HTTPAuthSecurityScheme +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + TaskState, + TextPart, + SecurityScheme, + HTTPAuthSecurityScheme, +) from a2a.utils import new_agent_text_message, new_task from starlette.routing import Route @@ -27,7 +35,8 @@ from slack_researcher.main import SlackAgent logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig(level=logging.DEBUG, stream=sys.stdout, format="%(levelname)s: %(message)s") + def get_agent_card(host: str, port: int): """Returns the Agent Card for the AG2 Agent.""" @@ -54,10 +63,7 @@ def get_agent_card(host: str, port: int): securitySchemes={ "Bearer": SecurityScheme( root=HTTPAuthSecurityScheme( - type="http", - scheme="bearer", - bearerFormat="JWT", - description="OAuth 2.0 JWT token" + type="http", scheme="bearer", bearerFormat="JWT", description="OAuth 2.0 JWT token" ) ) }, @@ -110,13 +116,15 @@ class ResearchExecutor(AgentExecutor): """ A class to handle research execution for A2A Agent. """ - async def _run_agent(self, + + async def _run_agent( + self, messages: dict, settings: Settings, event_emitter: Event, assistant_tool_map: dict[str, Callable], - toolkit: Toolkit): - + toolkit: Toolkit, + ): slack_agent = SlackAgent( config=settings, eventer=event_emitter, @@ -163,30 +171,39 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if settings.MCP_URL: logging.info("Connecting to MCP server at %s", settings.MCP_URL) - async with streamablehttp_client( - url=settings.MCP_URL, - ) as ( - read_stream, - write_stream, - _, - ), ClientSession(read_stream, write_stream) as session: + async with ( + streamablehttp_client( + url=settings.MCP_URL, + ) as ( + read_stream, + write_stream, + _, + ), + ClientSession(read_stream, write_stream) as session, + ): await session.initialize() - toolkit = await create_toolkit( - session=session, use_mcp_resources=False - ) - await self._run_agent(messages, settings, + toolkit = await create_toolkit(session=session, use_mcp_resources=False) + await self._run_agent( + messages, + settings, event_emitter, assistant_tool_map, - toolkit,) + toolkit, + ) else: - await self._run_agent(messages, settings, + await self._run_agent( + messages, + settings, event_emitter, assistant_tool_map, - toolkit,) + toolkit, + ) except Exception as e: traceback.print_exc() - await event_emitter.emit_event(f"I'm sorry I was unable to fulfill your request. I encountered the following exception: {str(e)}", True) + await event_emitter.emit_event( + f"I'm sorry I was unable to fulfill your request. I encountered the following exception: {str(e)}", True + ) async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """ @@ -214,11 +231,14 @@ def run(): app = server.build() # this returns a Starlette app # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) uvicorn.run(app, host="0.0.0.0", port=settings.SERVICE_PORT) diff --git a/a2a/slack_researcher/slack_researcher/agents.py b/a2a/slack_researcher/slack_researcher/agents.py index 792d00b3..771272f2 100644 --- a/a2a/slack_researcher/slack_researcher/agents.py +++ b/a2a/slack_researcher/slack_researcher/agents.py @@ -2,7 +2,7 @@ import sys from typing import Callable -from autogen import coding, ConversableAgent, register_function +from autogen import ConversableAgent, register_function from autogen.mcp.mcp_client import Toolkit from slack_researcher.config import Settings, settings @@ -11,22 +11,20 @@ ASSISTANT_PROMPT, REQUIREMENT_IDENTIFIER_PROMPT, CHANNEL_FILTER_PROMPT, - SUMMARIZER_PROMPT + SUMMARIZER_PROMPT, ) logger = logging.getLogger(__name__) -logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format="%(levelname)s: %(message)s") class Agents: - def __init__( self, config: Settings = None, assistant_tools: dict[str, Callable] = None, mcp_toolkit: Toolkit = None, ): - if not config: config = Settings() @@ -108,7 +106,7 @@ def __init__( mcp_toolkit.register_for_llm(self.slack_channel_assistant) tool_descriptions = [] for tool in mcp_toolkit.tools: - tool_descriptions.append({tool.name : tool.description}) + tool_descriptions.append({tool.name: tool.description}) tool_descriptions = str(tool_descriptions) logging.info("Tool descriptions: %s", tool_descriptions) else: diff --git a/a2a/slack_researcher/slack_researcher/config.py b/a2a/slack_researcher/slack_researcher/config.py index 7b10ed97..4c5c324e 100644 --- a/a2a/slack_researcher/slack_researcher/config.py +++ b/a2a/slack_researcher/slack_researcher/config.py @@ -1,13 +1,13 @@ import json -import logging import os from pydantic_settings import BaseSettings from pydantic import model_validator from pydantic import Field -from typing import Literal, Optional +from typing import Literal + class Settings(BaseSettings): - LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = Field( + LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( os.getenv("LOG_LEVEL", "DEBUG"), description="Application log level", ) @@ -31,7 +31,9 @@ class Settings(BaseSettings): description="The maximum number of plan steps", ge=1, ) - MCP_URL: str = Field(os.getenv("MCP_URL", "http://slack-tool:8000"), description="Endpoint for an option MCP server") + MCP_URL: str = Field( + os.getenv("MCP_URL", "http://slack-tool:8000"), description="Endpoint for an option MCP server" + ) SERVICE_PORT: int = Field(os.getenv("SERVICE_URL", 8000), description="Port on which the service will run.") class Config: @@ -47,4 +49,5 @@ def validate_extra_headers(self) -> "Settings": raise ValueError("EXTRA_HEADERS must be a valid JSON string") return self + settings = Settings() # type: ignore[call-arg] diff --git a/a2a/slack_researcher/slack_researcher/data_types.py b/a2a/slack_researcher/slack_researcher/data_types.py index b745c27f..a870dcd4 100644 --- a/a2a/slack_researcher/slack_researcher/data_types.py +++ b/a2a/slack_researcher/slack_researcher/data_types.py @@ -5,20 +5,33 @@ # Pydantic types for LLM response formats ############ + class ChannelInfo(BaseModel): name: str = Field(description="The name of the slack channel") id: str = Field(description="The ID of the channel") description: str = Field(description="A description of the channel") + class ChannelList(BaseModel): channels: list[ChannelInfo] - explanation: Optional[str] = Field(None, description="A detailed explanation as to why you chose this specific list of channels") + explanation: Optional[str] = Field( + None, description="A detailed explanation as to why you chose this specific list of channels" + ) + class UserIntent(BaseModel): intent: Literal["LIST_CHANNELS", "QUERY CHANNELS"] + class UserRequirement(BaseModel): - specific_channel_names: Optional[str] = Field(None, description="Specific channel names that the user would like to fetch") - types_of_channels: str = Field(description="A description of the types of channels that the user would like information about, if their request is not limited to specific channel names.") - types_of_information_to_search: Optional[str] = Field(None, description="The types of information that the user wants to look for inside of the channels. \ - Can be null if user is only interested in listing channels and not interested in searching inside channel content.") \ No newline at end of file + specific_channel_names: Optional[str] = Field( + None, description="Specific channel names that the user would like to fetch" + ) + types_of_channels: str = Field( + description="A description of the types of channels that the user would like information about, if their request is not limited to specific channel names." + ) + types_of_information_to_search: Optional[str] = Field( + None, + description="The types of information that the user wants to look for inside of the channels. \ + Can be null if user is only interested in listing channels and not interested in searching inside channel content.", + ) diff --git a/a2a/slack_researcher/slack_researcher/llm.py b/a2a/slack_researcher/slack_researcher/llm.py index 7ad0911e..e8e0e492 100644 --- a/a2a/slack_researcher/slack_researcher/llm.py +++ b/a2a/slack_researcher/slack_researcher/llm.py @@ -21,16 +21,8 @@ def _create_llm_config(self, config, response_format): "config_list": [ { **self._base_config, - **( - {"response_format": response_format} - if response_format - else {} - ), - **( - {"default_headers": config.EXTRA_HEADERS} - if config.EXTRA_HEADERS - else {} - ), + **({"response_format": response_format} if response_format else {}), + **({"default_headers": config.EXTRA_HEADERS} if config.EXTRA_HEADERS else {}), } ], "temperature": config.MODEL_TEMPERATURE, diff --git a/a2a/slack_researcher/slack_researcher/main.py b/a2a/slack_researcher/slack_researcher/main.py index 32fea120..7e594806 100644 --- a/a2a/slack_researcher/slack_researcher/main.py +++ b/a2a/slack_researcher/slack_researcher/main.py @@ -1,4 +1,3 @@ -import asyncio import json import logging import sys @@ -11,15 +10,18 @@ logger = logging.getLogger(__name__) -logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig(level=settings.LOG_LEVEL, stream=sys.stdout, format="%(levelname)s: %(message)s") + class SlackAgent: - def __init__(self, config: Settings, + def __init__( + self, + config: Settings, eventer: Event = None, assistant_tools: dict[str, Callable] = None, mcp_toolkit: Toolkit = None, - logger=None,): - + logger=None, + ): self.agents = Agents(settings, assistant_tools, mcp_toolkit) self.eventer = eventer self.logger = logger or logging.getLogger(__name__) @@ -37,7 +39,7 @@ async def _send_event(self, message: str, final: bool = False): await self.eventer.emit_event(message, final) else: self.logger.warning("No event handler registered") - + async def execute(self, user_query): self.user_query = self.extract_user_input(user_query) await self.classify_intent() @@ -65,29 +67,37 @@ def extract_user_input(self, body): self.logger.warning(f"Ignoring content with type {item['type']}") return latest_content - + async def classify_intent(self): prompt = f"Classify the intent of the user as either simply needing to list slack channel information or if their intent is querying the content of slack channels themselves. User query: {self.user_query}" - response = await self.agents.user_proxy.a_initiate_chat(message=prompt, recipient=self.agents.intent_classifier, max_turns=1) + response = await self.agents.user_proxy.a_initiate_chat( + message=prompt, recipient=self.agents.intent_classifier, max_turns=1 + ) self.user_intent = UserIntent(**json.loads(response.chat_history[-1]["content"])) await self._send_event(f"🧐 Identified user intent: {self.user_intent.intent}") async def identify_requirements(self): - response = await self.agents.user_proxy.a_initiate_chat(message=self.user_query, recipient=self.agents.requirement_identifier, max_turns=1) + response = await self.agents.user_proxy.a_initiate_chat( + message=self.user_query, recipient=self.agents.requirement_identifier, max_turns=1 + ) self.requirements = UserRequirement(**json.loads(response.chat_history[-1]["content"])) - await self._send_event(f"📇 Identified channel requirements. Channel names: {self.requirements.specific_channel_names}, Channel types: {self.requirements.types_of_channels}") + await self._send_event( + f"📇 Identified channel requirements. Channel names: {self.requirements.specific_channel_names}, Channel types: {self.requirements.types_of_channels}" + ) async def list_all_channels(self): await self._send_event("🔎 Fetching all channels") - response = await self.agents.user_proxy.a_initiate_chat(message="Retrieve all slack channels that are found on my slack server. Use the slack tool to find it.", - recipient=self.agents.slack_channel_assistant, - max_turns=3) + response = await self.agents.user_proxy.a_initiate_chat( + message="Retrieve all slack channels that are found on my slack server. Use the slack tool to find it.", + recipient=self.agents.slack_channel_assistant, + max_turns=3, + ) for item in response.chat_history: if item.get("tool_responses"): for tool_response in item["tool_responses"]: self.all_channels += tool_response.get("content") return response - + async def get_relevant_channels(self): self._send_event("👀 Identifying relevant channels") prompt = "" @@ -99,17 +109,22 @@ async def get_relevant_channels(self): prompt += f"User is looking for channels of any name that meet the following criteria: {self.requirements.types_of_channels}" prompt += f"\n The list of slack channels is as follows: {self.all_channels}" - response = await self.agents.user_proxy.a_initiate_chat(message=prompt, recipient=self.agents.channel_assistant_no_tools, max_turns=1) + response = await self.agents.user_proxy.a_initiate_chat( + message=prompt, recipient=self.agents.channel_assistant_no_tools, max_turns=1 + ) self.relevant_channels = ChannelList(**json.loads(response.chat_history[-1]["content"])) channel_names = [channel.name for channel in self.relevant_channels.channels] - await self._send_event(f"🎯 Relevant channels identified: {channel_names}. Reason: {self.relevant_channels.explanation}") - + await self._send_event( + f"🎯 Relevant channels identified: {channel_names}. Reason: {self.relevant_channels.explanation}" + ) async def query_channel(self, channel: ChannelInfo): await self._send_event(f"📖 Querying channel {channel.name}") - prompt = f"Retrieve the history from the slack channel with ID \"{channel.id}\" using the Slack tool available to you. The data retrieved will be used to answer the following user query/instruction: {self.user_query}" - response = await self.agents.user_proxy.a_initiate_chat(message=prompt, recipient=self.agents.slack_channel_assistant, max_turns=3) + prompt = f'Retrieve the history from the slack channel with ID "{channel.id}" using the Slack tool available to you. The data retrieved will be used to answer the following user query/instruction: {self.user_query}' + response = await self.agents.user_proxy.a_initiate_chat( + message=prompt, recipient=self.agents.slack_channel_assistant, max_turns=3 + ) # We're going to capture the raw channel data for analysis later channel_data = "" @@ -126,9 +141,11 @@ async def query_channel(self, channel: ChannelInfo): async def query_channels(self): for channel in self.relevant_channels.channels: self.channel_outputs.append(await self.query_channel(channel)) - + async def summarize_data(self, data_to_summarize): - await self._send_event(f"📄 Generating a final report") + await self._send_event("📄 Generating a final report") prompt = f"User query: {self.user_query}. \n Information gathered: {data_to_summarize}" - response = await self.agents.user_proxy.a_initiate_chat(message=prompt, recipient=self.agents.report_generator, max_turns=1) + response = await self.agents.user_proxy.a_initiate_chat( + message=prompt, recipient=self.agents.report_generator, max_turns=1 + ) return response.chat_history[-1]["content"] diff --git a/a2a/slack_researcher/slack_researcher/prompts.py b/a2a/slack_researcher/slack_researcher/prompts.py index 29b16637..2bb7800b 100644 --- a/a2a/slack_researcher/slack_researcher/prompts.py +++ b/a2a/slack_researcher/slack_researcher/prompts.py @@ -140,4 +140,4 @@ You are a helpful assistant who will produce a detailed report to directly address the user's query. You will use ONLY the following data that has been gathered from slack. Where possible, identify names of channels and names of users, not just their IDs. If you are unable to answer or only able to partially answer due to missing information or a specific error, please give detail to this. -""" \ No newline at end of file +""" diff --git a/a2a/trivia_agent/src/trivia_agent/__init__.py b/a2a/trivia_agent/src/trivia_agent/__init__.py index 6086822a..ee3c8a6c 100644 --- a/a2a/trivia_agent/src/trivia_agent/__init__.py +++ b/a2a/trivia_agent/src/trivia_agent/__init__.py @@ -1,16 +1,20 @@ -from opentelemetry.sdk.resources import Resource from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + def setup_tracer(): - resource = Resource.create(attributes={ - "service.name": "trivia-agent", - }) + resource = Resource.create( + attributes={ + "service.name": "trivia-agent", + } + ) provider = TracerProvider(resource=resource) processor = BatchSpanProcessor(OTLPSpanExporter()) provider.add_span_processor(processor) trace.set_tracer_provider(provider) + setup_tracer() diff --git a/a2a/trivia_agent/src/trivia_agent/agent.py b/a2a/trivia_agent/src/trivia_agent/agent.py index 41131915..7bb88e64 100644 --- a/a2a/trivia_agent/src/trivia_agent/agent.py +++ b/a2a/trivia_agent/src/trivia_agent/agent.py @@ -1,10 +1,10 @@ import logging +from textwrap import dedent import uvicorn from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route -from textwrap import dedent from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication @@ -13,7 +13,6 @@ from a2a.server.tasks import InMemoryTaskStore, TaskUpdater from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task - from trivia_agent.trivia_agent_llm import chat logging.basicConfig(level=logging.DEBUG) diff --git a/a2a/trivia_agent/src/trivia_agent/configuration.py b/a2a/trivia_agent/src/trivia_agent/configuration.py index d837750a..5b009f42 100644 --- a/a2a/trivia_agent/src/trivia_agent/configuration.py +++ b/a2a/trivia_agent/src/trivia_agent/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "qwen3:4b" llm_api_base: str = "http://host.docker.internal:11434/v1" diff --git a/a2a/weather_service/src/weather_service/agent.py b/a2a/weather_service/src/weather_service/agent.py index 54f17736..3b555579 100644 --- a/a2a/weather_service/src/weather_service/agent.py +++ b/a2a/weather_service/src/weather_service/agent.py @@ -1,22 +1,25 @@ import logging import os -import uvicorn from textwrap import dedent +import uvicorn +from langchain_core.messages import HumanMessage +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.routing import Route + from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2AStarletteApplication from a2a.server.events.event_queue import EventQueue -from starlette.routing import Route from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore, TaskUpdater from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task -from langchain_core.messages import HumanMessage - -from starlette.middleware.base import BaseHTTPMiddleware - from weather_service.graph import get_graph, get_mcpclient -from weather_service.observability import create_tracing_middleware, set_span_output, get_root_span +from weather_service.observability import ( + create_tracing_middleware, + get_root_span, + set_span_output, +) logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -56,6 +59,7 @@ def get_agent_card(host: str, port: int): skills=[skill], ) + class A2AEvent: """ A class to handle events for A2A Agent. @@ -87,10 +91,12 @@ async def emit_event(self, message: str, final: bool = False, failed: bool = Fal ), ) + class WeatherExecutor(AgentExecutor): """ A class to handle weather assistant execution for A2A Agent. """ + async def execute(self, context: RequestContext, event_queue: EventQueue): """ The agent allows to retrieve weather info through a natural language conversational interface @@ -110,30 +116,33 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): # Parse Messages messages = [HumanMessage(content=user_input)] input = {"messages": messages} - logger.info(f'Processing messages: {input}') + logger.info(f"Processing messages: {input}") # Note: Root span with MLflow attributes is created by tracing middleware # Here we just run the agent logic - spans from LangChain are auto-captured output = None # Test MCP connection first - logger.info(f'Attempting to connect to MCP server at: {os.getenv("MCP_URL", "http://localhost:8000/sse")}') + logger.info(f"Attempting to connect to MCP server at: {os.getenv('MCP_URL', 'http://localhost:8000/sse')}") mcpclient = get_mcpclient() # Try to get tools to verify connection try: tools = await mcpclient.get_tools() - logger.info(f'Successfully connected to MCP server. Available tools: {[tool.name for tool in tools]}') + logger.info(f"Successfully connected to MCP server. Available tools: {[tool.name for tool in tools]}") except Exception as tool_error: - logger.error(f'Failed to connect to MCP server: {tool_error}') - await event_emitter.emit_event(f"Error: Cannot connect to MCP weather service at {os.getenv('MCP_URL', 'http://localhost:8000/sse')}. Please ensure the weather MCP server is running. Error: {tool_error}", failed=True) + logger.error(f"Failed to connect to MCP server: {tool_error}") + await event_emitter.emit_event( + f"Error: Cannot connect to MCP weather service at {os.getenv('MCP_URL', 'http://localhost:8000/sse')}. Please ensure the weather MCP server is running. Error: {tool_error}", + failed=True, + ) return try: graph = await get_graph(mcpclient) except Exception as graph_error: - logger.error(f'Failed to create LLM graph: {graph_error}') + logger.error(f"Failed to create LLM graph: {graph_error}") await event_emitter.emit_event(f"Error: Failed to initialize LLM graph: {graph_error}", failed=True) return @@ -147,9 +156,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): + "\n" ) output = event - logger.info(f'event: {event}') + logger.info(f"event: {event}") except Exception as llm_error: - logger.error(f'LLM execution failed: {llm_error}') + logger.error(f"LLM execution failed: {llm_error}") await event_emitter.emit_event(f"Error: LLM execution failed: {llm_error}", failed=True) return @@ -172,6 +181,7 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None """ raise Exception("cancel not supported") + def run(): """ Runs the A2A Agent application. @@ -192,12 +202,15 @@ def run(): app = server.build() # Add the new agent-card.json path alongside the legacy agent.json path - app.routes.insert(0, Route( - '/.well-known/agent-card.json', - server._handle_get_agent_card, - methods=['GET'], - name='agent_card_new', - )) + app.routes.insert( + 0, + Route( + "/.well-known/agent-card.json", + server._handle_get_agent_card, + methods=["GET"], + name="agent_card_new", + ), + ) # Add tracing middleware - creates root span with MLflow/GenAI attributes app.add_middleware(BaseHTTPMiddleware, dispatch=create_tracing_middleware()) @@ -206,7 +219,9 @@ def run(): @app.middleware("http") async def log_authorization_header(request, call_next): auth_header = request.headers.get("authorization", "No Authorization header") - logger.info(f"🔐 Incoming request to {request.url.path} with Authorization: {auth_header[:80] + '...' if len(auth_header) > 80 else auth_header}") + logger.info( + f"🔐 Incoming request to {request.url.path} with Authorization: {auth_header[:80] + '...' if len(auth_header) > 80 else auth_header}" + ) response = await call_next(request) return response diff --git a/a2a/weather_service/src/weather_service/configuration.py b/a2a/weather_service/src/weather_service/configuration.py index 0c7dcb66..dd1c9f4c 100644 --- a/a2a/weather_service/src/weather_service/configuration.py +++ b/a2a/weather_service/src/weather_service/configuration.py @@ -1,5 +1,6 @@ from pydantic_settings import BaseSettings + class Configuration(BaseSettings): llm_model: str = "llama3.1" llm_api_base: str = "http://localhost:11434/v1" diff --git a/a2a/weather_service/src/weather_service/graph.py b/a2a/weather_service/src/weather_service/graph.py index 3e5d671d..cd9e8cf2 100644 --- a/a2a/weather_service/src/weather_service/graph.py +++ b/a2a/weather_service/src/weather_service/graph.py @@ -1,24 +1,31 @@ -from langgraph.graph import StateGraph, MessagesState, START +import os + +from langchain_core.messages import AIMessage, SystemMessage from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_core.messages import SystemMessage, AIMessage -from langgraph.prebuilt import tools_condition, ToolNode from langchain_openai import ChatOpenAI -import os +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + from weather_service.configuration import Configuration config = Configuration() + # Extend MessagesState to include a final answer class ExtendedMessagesState(MessagesState): - final_answer: str = "" + final_answer: str = "" + def get_mcpclient(): - return MultiServerMCPClient({ - "math": { - "url": os.getenv("MCP_URL", "http://localhost:8000/mcp"), - "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + return MultiServerMCPClient( + { + "math": { + "url": os.getenv("MCP_URL", "http://localhost:8000/mcp"), + "transport": os.getenv("MCP_TRANSPORT", "streamable_http"), + } } - }) + ) + async def get_graph(client) -> StateGraph: llm = ChatOpenAI( @@ -33,7 +40,9 @@ async def get_graph(client) -> StateGraph: llm_with_tools = llm.bind_tools(tools) # System message - sys_msg = SystemMessage(content="You are a helpful assistant tasked with providing weather information. You must use the provided tools to complete your task.") + sys_msg = SystemMessage( + content="You are a helpful assistant tasked with providing weather information. You must use the provided tools to complete your task." + ) # Node def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: @@ -61,6 +70,7 @@ def assistant(state: ExtendedMessagesState) -> ExtendedMessagesState: graph = builder.compile() return graph + # async def main(): # from langchain_core.messages import HumanMessage # client = get_mcpclient() diff --git a/a2a/weather_service/src/weather_service/observability.py b/a2a/weather_service/src/weather_service/observability.py index e713de94..fb2154b7 100644 --- a/a2a/weather_service/src/weather_service/observability.py +++ b/a2a/weather_service/src/weather_service/observability.py @@ -11,18 +11,19 @@ import json import logging import os -from contextvars import ContextVar -from typing import Dict, Any, Optional from contextlib import contextmanager -from opentelemetry import trace, context +from contextvars import ContextVar +from typing import Dict, Optional + +from opentelemetry import context, trace +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.propagate import extract, set_global_textmap +from opentelemetry.propagators.composite import CompositePropagator +from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_VERSION, Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.sdk.resources import Resource, SERVICE_NAME, SERVICE_VERSION -from opentelemetry.trace import Status, StatusCode, SpanKind -from opentelemetry.propagate import set_global_textmap, extract -from opentelemetry.propagators.composite import CompositePropagator +from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from opentelemetry.baggage.propagation import W3CBaggagePropagator logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ # ContextVar to pass root span from middleware to agent code # This allows execute() to access the middleware-created root span # even though trace.get_current_span() would return a child span -_root_span_var: ContextVar = ContextVar('root_span', default=None) +_root_span_var: ContextVar = ContextVar("root_span", default=None) def get_root_span(): @@ -48,9 +49,11 @@ def get_root_span(): """ return _root_span_var.get() + # OpenInference semantic conventions try: - from openinference.semconv.trace import SpanAttributes, OpenInferenceSpanKindValues + from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes + OPENINFERENCE_AVAILABLE = True except ImportError: OPENINFERENCE_AVAILABLE = False @@ -60,6 +63,7 @@ def get_root_span(): def _get_otlp_exporter(endpoint: str): """Get HTTP OTLP exporter.""" from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + if not endpoint.endswith("/v1/traces"): endpoint = endpoint.rstrip("/") + "/v1/traces" return OTLPSpanExporter(endpoint=endpoint) @@ -75,7 +79,7 @@ def setup_observability() -> None: namespace = os.getenv("K8S_NAMESPACE_NAME", "team1") otlp_endpoint = os.getenv( "OTEL_EXPORTER_OTLP_ENDPOINT", - "http://otel-collector.kagenti-system.svc.cluster.local:8335" + "http://otel-collector.kagenti-system.svc.cluster.local:8335", ) logger.info("=" * 60) @@ -88,46 +92,52 @@ def setup_observability() -> None: # Create resource with service and MLflow attributes # Resource attributes are STATIC and apply to ALL spans/traces # See: https://mlflow.org/docs/latest/genai/tracing/opentelemetry/ - resource = Resource(attributes={ - # Standard OTEL service attributes - SERVICE_NAME: service_name, - SERVICE_VERSION: AGENT_VERSION, - "service.namespace": namespace, - "k8s.namespace.name": namespace, - # MLflow static metadata (applies to all traces) - # These appear in MLflow trace list columns - "mlflow.traceName": AGENT_NAME, - "mlflow.source": service_name, - # GenAI static attributes - "gen_ai.agent.name": AGENT_NAME, - "gen_ai.agent.version": AGENT_VERSION, - "gen_ai.system": AGENT_FRAMEWORK, - }) + resource = Resource( + attributes={ + # Standard OTEL service attributes + SERVICE_NAME: service_name, + SERVICE_VERSION: AGENT_VERSION, + "service.namespace": namespace, + "k8s.namespace.name": namespace, + # MLflow static metadata (applies to all traces) + # These appear in MLflow trace list columns + "mlflow.traceName": AGENT_NAME, + "mlflow.source": service_name, + # GenAI static attributes + "gen_ai.agent.name": AGENT_NAME, + "gen_ai.agent.version": AGENT_VERSION, + "gen_ai.system": AGENT_FRAMEWORK, + } + ) # Create and configure tracer provider tracer_provider = TracerProvider(resource=resource) - tracer_provider.add_span_processor( - BatchSpanProcessor(_get_otlp_exporter(otlp_endpoint)) - ) + tracer_provider.add_span_processor(BatchSpanProcessor(_get_otlp_exporter(otlp_endpoint))) trace.set_tracer_provider(tracer_provider) # Auto-instrument LangChain with OpenInference try: from openinference.instrumentation.langchain import LangChainInstrumentor + LangChainInstrumentor().instrument() logger.info("LangChain instrumented with OpenInference") except ImportError: logger.warning("openinference-instrumentation-langchain not available") # Configure W3C Trace Context propagation - set_global_textmap(CompositePropagator([ - TraceContextTextMapPropagator(), - W3CBaggagePropagator(), - ])) + set_global_textmap( + CompositePropagator( + [ + TraceContextTextMapPropagator(), + W3CBaggagePropagator(), + ] + ) + ) # Instrument OpenAI for GenAI semantic conventions try: from opentelemetry.instrumentation.openai import OpenAIInstrumentor + OpenAIInstrumentor().instrument() logger.info("OpenAI instrumented with GenAI semantic conventions") except ImportError: @@ -168,7 +178,7 @@ def _set_genai_mlflow_attributes( if OPENINFERENCE_AVAILABLE: span.set_attribute( SpanAttributes.OPENINFERENCE_SPAN_KIND, - OpenInferenceSpanKindValues.AGENT.value + OpenInferenceSpanKindValues.AGENT.value, ) # === MLflow-specific Attributes === @@ -390,9 +400,9 @@ def create_tracing_middleware(): app = server.build() app.add_middleware(BaseHTTPMiddleware, dispatch=create_tracing_middleware()) """ + from starlette.requests import Request from starlette.responses import Response, StreamingResponse - import io async def tracing_middleware(request: Request, call_next): # Skip non-API paths (health checks, agent card, etc.) @@ -493,9 +503,7 @@ async def tracing_middleware(request: Request, call_next): # Try to capture response for output attributes # Note: This only works for non-streaming responses - if isinstance(response, Response) and not isinstance( - response, StreamingResponse - ): + if isinstance(response, Response) and not isinstance(response, StreamingResponse): # Read response body - we MUST recreate response after this response_body = b"" async for chunk in response.body_iterator: @@ -512,15 +520,9 @@ async def tracing_middleware(request: Request, call_next): if parts: output_text = parts[0].get("text", "") if output_text: - span.set_attribute( - "gen_ai.completion", output_text[:1000] - ) - span.set_attribute( - "output.value", output_text[:1000] - ) - span.set_attribute( - "mlflow.spanOutputs", output_text[:1000] - ) + span.set_attribute("gen_ai.completion", output_text[:1000]) + span.set_attribute("output.value", output_text[:1000]) + span.set_attribute("mlflow.spanOutputs", output_text[:1000]) except Exception as e: logger.debug(f"Could not parse response body: {e}") diff --git a/mcp/appworld_apis/entrypoint.py b/mcp/appworld_apis/entrypoint.py index d376aba8..b2207404 100644 --- a/mcp/appworld_apis/entrypoint.py +++ b/mcp/appworld_apis/entrypoint.py @@ -2,8 +2,8 @@ import signal import sys import threading -from inspect import signature from importlib import import_module +from inspect import signature from multiprocessing import Process from appworld import update_root @@ -44,8 +44,7 @@ def _coerce_db_path_for_docker_mode(path: str | None, appworld_root: str) -> str if _ensure_under(data_root, resolved) or _ensure_under(outputs_root, resolved): return resolved raise ValueError( - "DB path is outside allowed roots. " - "Allowed: APPWORLD_ROOT/data and APPWORLD_ROOT/experiments/outputs" + "DB path is outside allowed roots. Allowed: APPWORLD_ROOT/data and APPWORLD_ROOT/experiments/outputs" ) diff --git a/mcp/cloud_storage_tool/__init__.py b/mcp/cloud_storage_tool/__init__.py index 9ebcf029..357cd0e2 100644 --- a/mcp/cloud_storage_tool/__init__.py +++ b/mcp/cloud_storage_tool/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/mcp/cloud_storage_tool/cloud_storage_tool.py b/mcp/cloud_storage_tool/cloud_storage_tool.py index 5899d53d..3d2b7dc0 100644 --- a/mcp/cloud_storage_tool/cloud_storage_tool.py +++ b/mcp/cloud_storage_tool/cloud_storage_tool.py @@ -2,15 +2,20 @@ import logging import os import sys -from typing import List, Dict, Any, Tuple +from typing import Any, Dict, List, Tuple + +import boto3 +from azure.storage.blob import BlobServiceClient from fastmcp import FastMCP from google.cloud import storage from google.oauth2 import service_account -import boto3 -from azure.storage.blob import BlobServiceClient logger = logging.getLogger(__name__) -logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"), stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO"), + stream=sys.stdout, + format="%(levelname)s: %(message)s", +) # GCP credentials @@ -27,6 +32,7 @@ AZURE_STORAGE_ACCOUNT_NAME = os.getenv("AZURE_STORAGE_ACCOUNT_NAME") AZURE_STORAGE_ACCOUNT_KEY = os.getenv("AZURE_STORAGE_ACCOUNT_KEY") + def parse_cloud_uri(uri: str) -> Tuple[str, str, str]: """Parse cloud storage URI and return (provider, bucket/container, path).""" if uri.startswith("gs://"): @@ -42,13 +48,14 @@ def parse_cloud_uri(uri: str) -> Tuple[str, str, str]: # If no scheme, raise error raise ValueError(f"Invalid cloud storage URI: {uri}") + def get_gcs_client(): """Create and return a GCS client using service account credentials.""" try: if GCP_SERVICE_ACCOUNT_KEY is None: logger.error("GCP_SERVICE_ACCOUNT_KEY environment variable not set") return None - + # Parse service account key from JSON string or file path if GCP_SERVICE_ACCOUNT_KEY.startswith("{"): # It's a JSON string @@ -57,7 +64,7 @@ def get_gcs_client(): else: # It's a file path credentials = service_account.Credentials.from_service_account_file(GCP_SERVICE_ACCOUNT_KEY) - + client = storage.Client(credentials=credentials, project=GCP_PROJECT_ID) logger.info("Successfully authenticated with GCP") return client @@ -65,26 +72,28 @@ def get_gcs_client(): logger.error(f"Error authenticating with GCP: {e}") return None + def get_s3_client(): """Create and return an S3 client using AWS credentials.""" try: if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY: client = boto3.client( - 's3', + "s3", aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION + region_name=AWS_REGION, ) else: # Use default credentials (IAM role, environment, etc.) - client = boto3.client('s3', region_name=AWS_REGION) - + client = boto3.client("s3", region_name=AWS_REGION) + logger.info("Successfully authenticated with AWS S3") return client except Exception as e: logger.error(f"Error authenticating with AWS S3: {e}") return None + def get_azure_blob_service_client(): """Create and return an Azure Blob Service client.""" try: @@ -96,281 +105,313 @@ def get_azure_blob_service_client(): else: logger.error("Azure credentials not configured") return None - + logger.info("Successfully authenticated with Azure Blob Storage") return client except Exception as e: logger.error(f"Error authenticating with Azure Blob Storage: {e}") return None + def list_objects_unified(provider: str, bucket_or_container: str) -> List[Dict[str, Any]]: """List objects from any cloud provider.""" objects = [] - + if provider == "gcs": storage_client = get_gcs_client() if not storage_client: raise Exception("Could not authenticate with GCP") - + bucket = storage_client.bucket(bucket_or_container) blobs = bucket.list_blobs() - + for blob in blobs: - objects.append({ - "name": blob.name, - "size": blob.size, - "content_type": blob.content_type, - "created": blob.time_created.isoformat() if blob.time_created else None, - "updated": blob.updated.isoformat() if blob.updated else None, - "storage_class": blob.storage_class, - "public_url": blob.public_url - }) - + objects.append( + { + "name": blob.name, + "size": blob.size, + "content_type": blob.content_type, + "created": blob.time_created.isoformat() if blob.time_created else None, + "updated": blob.updated.isoformat() if blob.updated else None, + "storage_class": blob.storage_class, + "public_url": blob.public_url, + } + ) + elif provider == "s3": s3_client = get_s3_client() if not s3_client: raise Exception("Could not authenticate with AWS S3") - - paginator = s3_client.get_paginator('list_objects_v2') + + paginator = s3_client.get_paginator("list_objects_v2") for page in paginator.paginate(Bucket=bucket_or_container): - for obj in page.get('Contents', []): - objects.append({ - "name": obj['Key'], - "size": obj['Size'], - "content_type": None, - "created": obj['LastModified'].isoformat() if 'LastModified' in obj else None, - "updated": obj['LastModified'].isoformat() if 'LastModified' in obj else None, - "storage_class": obj.get('StorageClass'), - "public_url": f"s3://{bucket_or_container}/{obj['Key']}" - }) - + for obj in page.get("Contents", []): + objects.append( + { + "name": obj["Key"], + "size": obj["Size"], + "content_type": None, + "created": obj["LastModified"].isoformat() if "LastModified" in obj else None, + "updated": obj["LastModified"].isoformat() if "LastModified" in obj else None, + "storage_class": obj.get("StorageClass"), + "public_url": f"s3://{bucket_or_container}/{obj['Key']}", + } + ) + elif provider == "azure": azure_client = get_azure_blob_service_client() if not azure_client: raise Exception("Could not authenticate with Azure Blob Storage") - + container_client = azure_client.get_container_client(bucket_or_container) blobs = container_client.list_blobs() - + for blob in blobs: - objects.append({ - "name": blob.name, - "size": blob.size, - "content_type": blob.content_settings.content_type if blob.content_settings else None, - "created": blob.creation_time.isoformat() if blob.creation_time else None, - "updated": blob.last_modified.isoformat() if blob.last_modified else None, - "storage_class": blob.blob_tier, - "public_url": f"azure://{bucket_or_container}/{blob.name}" - }) - + objects.append( + { + "name": blob.name, + "size": blob.size, + "content_type": blob.content_settings.content_type if blob.content_settings else None, + "created": blob.creation_time.isoformat() if blob.creation_time else None, + "updated": blob.last_modified.isoformat() if blob.last_modified else None, + "storage_class": blob.blob_tier, + "public_url": f"azure://{bucket_or_container}/{blob.name}", + } + ) + return objects -def copy_object_unified(provider: str, source_bucket: str, source_path: str, - target_bucket: str, target_path: str) -> bool: + +def copy_object_unified( + provider: str, + source_bucket: str, + source_path: str, + target_bucket: str, + target_path: str, +) -> bool: """Copy object within the same cloud provider.""" if provider == "gcs": storage_client = get_gcs_client() if not storage_client: raise Exception("Could not authenticate with GCP") - + source_bucket_obj = storage_client.bucket(source_bucket) source_blob = source_bucket_obj.blob(source_path) - + if not source_blob.exists(): raise Exception(f"Source file does not exist: gs://{source_bucket}/{source_path}") - + target_bucket_obj = storage_client.bucket(target_bucket) source_bucket_obj.copy_blob(source_blob, target_bucket_obj, target_path) return True - + elif provider == "s3": s3_client = get_s3_client() if not s3_client: raise Exception("Could not authenticate with AWS S3") - - copy_source = {'Bucket': source_bucket, 'Key': source_path} + + copy_source = {"Bucket": source_bucket, "Key": source_path} s3_client.copy_object(CopySource=copy_source, Bucket=target_bucket, Key=target_path) return True - + elif provider == "azure": azure_client = get_azure_blob_service_client() if not azure_client: raise Exception("Could not authenticate with Azure Blob Storage") - + source_blob_client = azure_client.get_blob_client(container=source_bucket, blob=source_path) target_blob_client = azure_client.get_blob_client(container=target_bucket, blob=target_path) - + if not source_blob_client.exists(): raise Exception(f"Source file does not exist: azure://{source_bucket}/{source_path}") - + target_blob_client.start_copy_from_url(source_blob_client.url) return True - + return False + def delete_object_unified(provider: str, bucket_or_container: str, path: str) -> bool: """Delete object from any cloud provider.""" if provider == "gcs": storage_client = get_gcs_client() if not storage_client: raise Exception("Could not authenticate with GCP") - + bucket = storage_client.bucket(bucket_or_container) blob = bucket.blob(path) blob.delete() return True - + elif provider == "s3": s3_client = get_s3_client() if not s3_client: raise Exception("Could not authenticate with AWS S3") - + s3_client.delete_object(Bucket=bucket_or_container, Key=path) return True - + elif provider == "azure": azure_client = get_azure_blob_service_client() if not azure_client: raise Exception("Could not authenticate with Azure Blob Storage") - + blob_client = azure_client.get_blob_client(container=bucket_or_container, blob=path) blob_client.delete_blob() return True - + return False + def download_text_unified(provider: str, bucket_or_container: str, path: str) -> str: """Download text content from any cloud provider.""" if provider == "gcs": storage_client = get_gcs_client() if not storage_client: raise Exception("Could not authenticate with GCP") - + bucket = storage_client.bucket(bucket_or_container) blob = bucket.blob(path) - + if not blob.exists(): raise Exception(f"File does not exist: gs://{bucket_or_container}/{path}") - + return blob.download_as_text() - + elif provider == "s3": s3_client = get_s3_client() if not s3_client: raise Exception("Could not authenticate with AWS S3") - + response = s3_client.get_object(Bucket=bucket_or_container, Key=path) - return response['Body'].read().decode('utf-8') - + return response["Body"].read().decode("utf-8") + elif provider == "azure": azure_client = get_azure_blob_service_client() if not azure_client: raise Exception("Could not authenticate with Azure Blob Storage") - + blob_client = azure_client.get_blob_client(container=bucket_or_container, blob=path) - + if not blob_client.exists(): raise Exception(f"File does not exist: azure://{bucket_or_container}/{path}") - - return blob_client.download_blob().readall().decode('utf-8') - + + return blob_client.download_blob().readall().decode("utf-8") + raise Exception(f"Unsupported provider: {provider}") + # Create FastMCP app mcp = FastMCP("CloudStorage") + @mcp.tool(annotations={"readOnlyHint": True, "destructiveHint": False, "idempotentHint": True}) def get_objects(bucket_uri: str) -> str: """Get all objects from a cloud storage bucket/container.""" try: # Parse URI to determine provider and bucket provider, bucket_name, _ = parse_cloud_uri(bucket_uri) - + logger.debug(f"Getting objects from {provider} bucket '{bucket_name}'") - + # Get the raw list of objects objects = list_objects_unified(provider, bucket_name) - + # Loop through and enrich each object with the full file_uri for obj in objects: # This assumes your list_objects_unified returns dicts # and that the object key is stored in the 'name' field. # Adjust 'name' if your key is stored differently (e.g., 'key') - if 'name' in obj: - obj['file_uri'] = f"{provider}://{bucket_name}/{obj['name']}" + if "name" in obj: + obj["file_uri"] = f"{provider}://{bucket_name}/{obj['name']}" else: logger.warning(f"Object {obj} missing 'name' key, cannot construct file_uri") - logger.debug(f"Successfully retrieved and processed {len(objects)} objects from {provider} bucket '{bucket_name}'") - - return json.dumps({ - "provider": provider, - "bucket": bucket_name, - "object_count": len(objects), - "objects": objects - }) - + logger.debug( + f"Successfully retrieved and processed {len(objects)} objects from {provider} bucket '{bucket_name}'" + ) + + return json.dumps( + { + "provider": provider, + "bucket": bucket_name, + "object_count": len(objects), + "objects": objects, + } + ) + except Exception as e: logger.error(f"Error listing objects: {e}") return json.dumps({"error": f"Failed to list objects: {str(e)}"}) -@mcp.tool(annotations={"readOnlyHint": False, "destructiveHint": True, "idempotentHint": False}) + +@mcp.tool( + annotations={ + "readOnlyHint": False, + "destructiveHint": True, + "idempotentHint": False, + } +) def perform_action(file_uri: str, target_uri: str) -> str: """ Move object between cloud storage locations. - + Args: file_uri: Source file URI (example: 'gs://bucket/path/file.txt') target_uri: Target folder URI (example: 'gs://bucket/folder/'). Must end with '/' for folder. """ - + # Validate target is a folder (ends with /) if not target_uri.endswith("/"): return json.dumps({"error": f"Target URI must be a folder path ending with '/': {target_uri}"}) - + try: # Parse source and target URIs source_provider, source_bucket, source_path = parse_cloud_uri(file_uri) target_provider, target_bucket, target_folder = parse_cloud_uri(target_uri) - + # Ensure providers match if source_provider != target_provider: - return json.dumps({"error": f"Cross-provider operations not supported. Source is {source_provider}, target is {target_provider}"}) - + return json.dumps( + { + "error": f"Cross-provider operations not supported. Source is {source_provider}, target is {target_provider}" + } + ) + # Extract filename from source path filename = os.path.basename(source_path) - + # Construct full target blob path (folder + filename) target_path = os.path.join(target_folder, filename).replace("\\", "/") - + # Construct full URIs for response full_source_uri = f"{source_provider}://{source_bucket}/{source_path}" full_target_uri = f"{target_provider}://{target_bucket}/{target_path}" - + # Perform copy operation copy_object_unified(source_provider, source_bucket, source_path, target_bucket, target_path) - - result = { - "status": "success" - } - + + result = {"status": "success"} + # If action is move, delete the source delete_object_unified(source_provider, source_bucket, source_path) logger.debug(f"Successfully moved '{full_source_uri}' to '{full_target_uri}'") result["message"] = f"File moved from {full_source_uri} to {full_target_uri}" - + return json.dumps(result) - + except Exception as e: logger.error(f"Error performing move operation: {e}") return json.dumps({"error": f"Failed to move file: {str(e)}"}) + def run_server(): transport = os.getenv("MCP_TRANSPORT", "streamable-http") host = os.getenv("HOST", "0.0.0.0") port = int(os.getenv("PORT", "8000")) mcp.run(transport=transport, host=host, port=port) + if __name__ == "__main__": configured_providers = [] if GCP_SERVICE_ACCOUNT_KEY and GCP_PROJECT_ID: @@ -379,10 +420,10 @@ def run_server(): configured_providers.append("AWS S3") if AZURE_STORAGE_CONNECTION_STRING or (AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY): configured_providers.append("Azure") - + if not configured_providers: logger.warning("No cloud provider credentials configured. Please set up at least one provider.") else: logger.info(f"Configured providers: {', '.join(configured_providers)}") - - run_server() \ No newline at end of file + + run_server() diff --git a/mcp/flight_tool/__init__.py b/mcp/flight_tool/__init__.py index 4660ec57..84fdda15 100644 --- a/mcp/flight_tool/__init__.py +++ b/mcp/flight_tool/__init__.py @@ -13,4 +13,3 @@ # limitations under the License. - diff --git a/mcp/flight_tool/flight_tool.py b/mcp/flight_tool/flight_tool.py index 10ec8aab..de42378f 100644 --- a/mcp/flight_tool/flight_tool.py +++ b/mcp/flight_tool/flight_tool.py @@ -4,53 +4,64 @@ import logging import os import sys -from typing import Any, Dict, List, Optional from datetime import date, datetime +from typing import Any, Dict, List, Optional -from fastmcp import FastMCP from fast_flights import ( FlightData, Passengers, Result, get_flights, - search_airport as ff_search_airport) - +) +from fast_flights import ( + search_airport as ff_search_airport, +) +from fastmcp import FastMCP mcp = FastMCP("Flights") logger = logging.getLogger(__name__) -logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"), stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO"), + stream=sys.stdout, + format="%(levelname)s: %(message)s", +) def _result_to_dict(r: Result) -> List[Dict[str, Any]]: - flights = getattr(r, 'flights', []) + flights = getattr(r, "flights", []) if not flights: - return [{ - "id": None, - "airline": None, - "price": "N/A", - "price_value": None, - "duration_minutes": None, - "stops": None, - "departure": None, - "arrival": None, - }] - + return [ + { + "id": None, + "airline": None, + "price": "N/A", + "price_value": None, + "duration_minutes": None, + "stops": None, + "departure": None, + "arrival": None, + } + ] + flight_results = [] for flight in flights: - flight_results.append({ - "id": getattr(flight, 'name', None), - "airline": getattr(flight, 'name', None), - "price_value": getattr(r, 'current_price', None), - "duration_minutes": getattr(flight, 'duration', None), - "stops": getattr(flight, 'stops', None), - "departure": getattr(flight, 'departure', None), - "arrival": getattr(flight, 'arrival', None), - "is_best": getattr(flight, 'is_best', False), - "delay": getattr(flight, 'delay', None), - }) - + flight_results.append( + { + "id": getattr(flight, "name", None), + "airline": getattr(flight, "name", None), + "price_value": getattr(r, "current_price", None), + "duration_minutes": getattr(flight, "duration", None), + "stops": getattr(flight, "stops", None), + "departure": getattr(flight, "departure", None), + "arrival": getattr(flight, "arrival", None), + "is_best": getattr(flight, "is_best", False), + "delay": getattr(flight, "delay", None), + } + ) + return flight_results + def _parse_iso_date(d: str) -> Optional[date]: if not d: return None @@ -85,13 +96,17 @@ def _coerce_int(val: Any, name: str, default: int) -> tuple[int, Optional[str]]: except Exception: return default, f"Invalid integer value for '{name}': {val!r}" else: - return default, f"Invalid type for '{name}': expected int or str, got {type(val).__name__}" + return ( + default, + f"Invalid type for '{name}': expected int or str, got {type(val).__name__}", + ) if i < 0: return default, f"'{name}' must be >= 0" return i, None + @mcp.tool(annotations={"readOnlyHint": True, "destructiveHint": False, "idempotentHint": True}) def search_airports(query: str, limit: int = 10) -> str: """Search for airports by name or code. @@ -165,113 +180,165 @@ def search_flights( # Validate dates are not in the past dep_date_obj = _parse_iso_date(departure_date) if dep_date_obj is None: - return json.dumps({"error": "Invalid departure_date format. Use YYYY-MM-DD", "departure_date": departure_date}) + return json.dumps( + { + "error": "Invalid departure_date format. Use YYYY-MM-DD", + "departure_date": departure_date, + } + ) if _date_in_past(dep_date_obj): - return json.dumps({"error": "departure_date cannot be in the past", "departure_date": departure_date}) + return json.dumps( + { + "error": "departure_date cannot be in the past", + "departure_date": departure_date, + } + ) ret_date_obj = None if return_date: ret_date_obj = _parse_iso_date(return_date) if ret_date_obj is None: - return json.dumps({"error": "Invalid return_date format. Use YYYY-MM-DD", "return_date": return_date}) + return json.dumps( + { + "error": "Invalid return_date format. Use YYYY-MM-DD", + "return_date": return_date, + } + ) if _date_in_past(ret_date_obj): - return json.dumps({"error": "return_date cannot be in the past", "return_date": return_date}) + return json.dumps( + { + "error": "return_date cannot be in the past", + "return_date": return_date, + } + ) # Ensure return date is not before departure if ret_date_obj < dep_date_obj: - return json.dumps({"error": "return_date cannot be before departure_date", "departure_date": departure_date, "return_date": return_date}) - + return json.dumps( + { + "error": "return_date cannot be before departure_date", + "departure_date": departure_date, + "return_date": return_date, + } + ) + flight_data_kwargs = { "date": departure_date, "from_airport": from_airport, "to_airport": to_airport, } - + if airlines: airline_list = [airline.strip().upper() for airline in airlines.split(",")] flight_data_kwargs["airlines"] = airline_list - + if max_stops is not None: flight_data_kwargs["max_stops"] = max_stops - + flight_data_list = [FlightData(**flight_data_kwargs)] - + # Add return flight for round-trip if return_date: return_flight_kwargs = flight_data_kwargs.copy() - return_flight_kwargs.update({ - "date": return_date, - "from_airport": to_airport, - "to_airport": from_airport, - }) + return_flight_kwargs.update( + { + "date": return_date, + "from_airport": to_airport, + "to_airport": from_airport, + } + ) flight_data_list.append(FlightData(**return_flight_kwargs)) trip_type = "round-trip" else: trip_type = "one-way" - + seat_mapping = { "economy": "economy", - "premium_economy": "premium_economy", + "premium_economy": "premium_economy", "business": "business", - "first": "first" + "first": "first", } seat_type = seat_mapping.get(cabin, "economy") - + total_passengers = adults + children + infants_in_seat + infants_on_lap if total_passengers > 9: - return json.dumps({ - "error": "Total passengers cannot exceed 9", - "request": { - "from_airport": from_airport, - "to_airport": to_airport, - "departure_date": departure_date, - "return_date": return_date, - "cabin": cabin, - "adults": adults, - "children": children, - "infants_in_seat": infants_in_seat, - "infants_on_lap": infants_on_lap, - "airlines": airlines, - "max_stops": max_stops, + return json.dumps( + { + "error": "Total passengers cannot exceed 9", + "request": { + "from_airport": from_airport, + "to_airport": to_airport, + "departure_date": departure_date, + "return_date": return_date, + "cabin": cabin, + "adults": adults, + "children": children, + "infants_in_seat": infants_in_seat, + "infants_on_lap": infants_on_lap, + "airlines": airlines, + "max_stops": max_stops, + }, } - }) - + ) + if infants_on_lap > adults: - return json.dumps({ - "error": "Must have at least one adult per infant on lap", - "request": { - "from_airport": from_airport, - "to_airport": to_airport, - "departure_date": departure_date, - "return_date": return_date, - "cabin": cabin, - "adults": adults, - "children": children, - "infants_in_seat": infants_in_seat, - "infants_on_lap": infants_on_lap, - "airlines": airlines, - "max_stops": max_stops, + return json.dumps( + { + "error": "Must have at least one adult per infant on lap", + "request": { + "from_airport": from_airport, + "to_airport": to_airport, + "departure_date": departure_date, + "return_date": return_date, + "cabin": cabin, + "adults": adults, + "children": children, + "infants_in_seat": infants_in_seat, + "infants_on_lap": infants_on_lap, + "airlines": airlines, + "max_stops": max_stops, + }, } - }) - + ) + passengers = Passengers( adults=adults, children=children, infants_in_seat=infants_in_seat, - infants_on_lap=infants_on_lap + infants_on_lap=infants_on_lap, ) - + logger.debug(f"Searching flights: {flight_data_list}") try: result: Result = get_flights( - flight_data=flight_data_list, - trip=trip_type, - seat=seat_type, - passengers=passengers, - fetch_mode="fallback" - ) + flight_data=flight_data_list, + trip=trip_type, + seat=seat_type, + passengers=passengers, + fetch_mode="fallback", + ) except Exception: - return json.dumps({ - "error": "An error occurred while fetching flight data, there may be no available flights for the given parameters.", + return json.dumps( + { + "error": "An error occurred while fetching flight data, there may be no available flights for the given parameters.", + "request": { + "from_airport": from_airport, + "to_airport": to_airport, + "departure_date": departure_date, + "return_date": return_date, + "cabin": cabin, + "adults": adults, + "children": children, + "infants_in_seat": infants_in_seat, + "infants_on_lap": infants_on_lap, + "airlines": airlines, + "max_stops": max_stops, + }, + } + ) + + summary: List[Dict[str, Any]] = _result_to_dict(result) + return json.dumps( + { "request": { "from_airport": from_airport, "to_airport": to_airport, @@ -284,27 +351,11 @@ def search_flights( "infants_on_lap": infants_on_lap, "airlines": airlines, "max_stops": max_stops, - } - }) - - summary: List[Dict[str, Any]] = _result_to_dict(result) - return json.dumps({ - "request": { - "from_airport": from_airport, - "to_airport": to_airport, - "departure_date": departure_date, - "return_date": return_date, - "cabin": cabin, - "adults": adults, - "children": children, - "infants_in_seat": infants_in_seat, - "infants_on_lap": infants_on_lap, - "airlines": airlines, - "max_stops": max_stops, - }, - "count": len(summary), - "summary": summary - }) + }, + "count": len(summary), + "summary": summary, + } + ) def run_server(): @@ -317,5 +368,3 @@ def run_server(): if __name__ == "__main__": run_server() - - diff --git a/mcp/image_tool/image_tool.py b/mcp/image_tool/image_tool.py index 0ab8fb08..d876c2f4 100644 --- a/mcp/image_tool/image_tool.py +++ b/mcp/image_tool/image_tool.py @@ -3,13 +3,18 @@ import base64 import logging import os -import requests import sys + +import requests from fastmcp import FastMCP mcp = FastMCP("Image") logger = logging.getLogger(__name__) -logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"), stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO"), + stream=sys.stdout, + format="%(levelname)s: %(message)s", +) @mcp.tool(annotations={"readOnlyHint": True, "destructiveHint": False, "idempotentHint": True}) @@ -23,7 +28,7 @@ def get_image(width: int, height: int) -> dict: Returns a dict containing: - image_base64: base64-encoded image data (string) - url: the source URL of the image (string) - + Example return value: {"image_base64": "/9j/4AAQSkZJRg...", "url": "https://picsum.photos/200/300"} """ diff --git a/mcp/movie_tool/movie_tool.py b/mcp/movie_tool/movie_tool.py index 447e62bd..cbe78aea 100644 --- a/mcp/movie_tool/movie_tool.py +++ b/mcp/movie_tool/movie_tool.py @@ -1,19 +1,25 @@ -import os import json -import requests -import sys -from fastmcp import FastMCP import logging +import os +import sys from typing import Any +import requests +from fastmcp import FastMCP + logger = logging.getLogger(__name__) -logging.basicConfig(level=os.getenv("LOG_LEVEL", "DEBUG"), stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "DEBUG"), + stream=sys.stdout, + format="%(levelname)s: %(message)s", +) logging.getLogger("urllib3").setLevel(logging.INFO) OMDB_API_KEY = os.getenv("OMDB_API_KEY") mcp = FastMCP("Movie Review") + def _fetch_json(params: dict[str, Any], timeout: int = 10) -> dict[str, Any]: """ Helper to perform a GET request and parse the JSON response from the OMDb API. @@ -36,22 +42,24 @@ def _fetch_json(params: dict[str, Any], timeout: int = 10) -> dict[str, Any]: logger.error("Error fetching data: %s", e) return {"Error": "Error fetching data"} + @mcp.tool(annotations={"readOnlyHint": True, "destructiveHint": False, "idempotentHint": True}) def get_full_plot(movie_title: str) -> str: """Get full plot summary of a movie from OMDb API.""" - + logger.debug("Requesting OMDb with t=%s plot=%s", movie_title, "full") params = {"t": movie_title, "plot": "full"} data = _fetch_json(params=params) - + if "Error" in data: return data["Error"] - + if "Response" in data and data["Response"] == "True" and "Plot" in data: return data["Plot"] - + return "Movie not found" + @mcp.tool(annotations={"readOnlyHint": True, "destructiveHint": False, "idempotentHint": True}) def get_movie_details(movie_title: str) -> str: """Get full details (awards, actors, short plot, and ratings, etc.) of a movie from OMDb API.""" @@ -59,7 +67,7 @@ def get_movie_details(movie_title: str) -> str: logger.debug("Requesting OMDb with t=%s plot=%s", movie_title, "short") params = {"t": movie_title, "plot": "short"} data = _fetch_json(params=params) - + if "Error" in data: return data["Error"] @@ -70,6 +78,7 @@ def get_movie_details(movie_title: str) -> str: return "Movie not found" + # host can be specified with HOST env variable # transport can be specified with MCP_TRANSPORT env variable (defaults to streamable-http) def run_server(): @@ -79,6 +88,7 @@ def run_server(): port = int(os.getenv("PORT", "8000")) mcp.run(transport=transport, host=host, port=port) + if __name__ == "__main__": if OMDB_API_KEY is None: logger.warning("Please configure the OMDB_API_KEY environment variable before running the server") diff --git a/mcp/reservation_tool/providers/base.py b/mcp/reservation_tool/providers/base.py index fca04051..9bcb61b8 100644 --- a/mcp/reservation_tool/providers/base.py +++ b/mcp/reservation_tool/providers/base.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from typing import List, Optional -from schemas import Restaurant, AvailabilitySlot, Reservation, CancellationReceipt + +from schemas import AvailabilitySlot, CancellationReceipt, Reservation, Restaurant class ReservationProvider(ABC): diff --git a/mcp/reservation_tool/providers/mock.py b/mcp/reservation_tool/providers/mock.py index 93f439b8..8e7c67d2 100644 --- a/mcp/reservation_tool/providers/mock.py +++ b/mcp/reservation_tool/providers/mock.py @@ -2,10 +2,17 @@ import hashlib import logging -from datetime import datetime, timedelta, timezone -from typing import List, Optional, Dict +from datetime import datetime, timezone +from typing import Dict, List, Optional + from providers.base import ReservationProvider -from schemas import Restaurant, Location, AvailabilitySlot, Reservation, CancellationReceipt +from schemas import ( + AvailabilitySlot, + CancellationReceipt, + Location, + Reservation, + Restaurant, +) logger = logging.getLogger(__name__) @@ -261,7 +268,16 @@ def check_availability( # Generate slots for lunch (11:30-14:00) and dinner (17:00-21:00) slots = [] lunch_times = ["11:30", "12:00", "12:30", "13:00", "13:30"] - dinner_times = ["17:00", "17:30", "18:00", "18:30", "19:00", "19:30", "20:00", "20:30"] + dinner_times = [ + "17:00", + "17:30", + "18:00", + "18:30", + "19:00", + "19:30", + "20:00", + "20:30", + ] all_times = lunch_times + dinner_times diff --git a/mcp/reservation_tool/reservation_tool.py b/mcp/reservation_tool/reservation_tool.py index e77ff342..c01d9d1e 100644 --- a/mcp/reservation_tool/reservation_tool.py +++ b/mcp/reservation_tool/reservation_tool.py @@ -4,13 +4,13 @@ and managing reservations through a provider abstraction layer. """ +import json +import logging import os import sys -import logging -import json from typing import Optional -from fastmcp import FastMCP +from fastmcp import FastMCP from providers import MockProvider, ReservationProvider # Setup logging @@ -18,7 +18,7 @@ logging.basicConfig( level=os.getenv("LOG_LEVEL", "INFO"), stream=sys.stdout, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) # Initialize provider @@ -92,7 +92,9 @@ def check_availability( Returns: JSON string containing list of available time slots """ - logger.info(f"check_availability called: restaurant={restaurant_id}, date_time={date_time}, party_size={party_size}") + logger.info( + f"check_availability called: restaurant={restaurant_id}, date_time={date_time}, party_size={party_size}" + ) try: slots = provider.check_availability( @@ -114,7 +116,13 @@ def check_availability( return json.dumps({"error": str(e)}) -@mcp.tool(annotations={"readOnlyHint": False, "destructiveHint": False, "idempotentHint": True}) +@mcp.tool( + annotations={ + "readOnlyHint": False, + "destructiveHint": False, + "idempotentHint": True, + } +) def place_reservation( restaurant_id: str, date_time: str, @@ -241,7 +249,9 @@ def run_server(): port = int(os.getenv("PORT", "8000")) logger.info(f"Starting Restaurant Reservation MCP Server on {host}:{port} with transport={transport}") - logger.info(f"Registered tools: search_restaurants, check_availability, place_reservation, cancel_reservation, list_reservations") + logger.info( + "Registered tools: search_restaurants, check_availability, place_reservation, cancel_reservation, list_reservations" + ) mcp.run(transport=transport, host=host, port=port) diff --git a/mcp/reservation_tool/schemas.py b/mcp/reservation_tool/schemas.py index fe4ee0b5..372c1cb0 100644 --- a/mcp/reservation_tool/schemas.py +++ b/mcp/reservation_tool/schemas.py @@ -1,6 +1,7 @@ """Data models for the Restaurant Reservation MCP server.""" from typing import Optional + from pydantic import BaseModel, Field diff --git a/mcp/reservation_tool/tests/test_reservation_tool.py b/mcp/reservation_tool/tests/test_reservation_tool.py index defef637..b20b9160 100644 --- a/mcp/reservation_tool/tests/test_reservation_tool.py +++ b/mcp/reservation_tool/tests/test_reservation_tool.py @@ -2,13 +2,14 @@ import sys from pathlib import Path + import pytest # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) from providers.mock import MockProvider -from schemas import Restaurant, AvailabilitySlot, Reservation +from schemas import AvailabilitySlot, Reservation, Restaurant class TestMockProvider: @@ -45,11 +46,7 @@ def test_search_restaurants_not_found(self, provider): def test_check_availability(self, provider): """Test checking availability for a restaurant.""" - slots = provider.check_availability( - restaurant_id="rest_001", - date_time="2025-03-15T12:00:00", - party_size=4 - ) + slots = provider.check_availability(restaurant_id="rest_001", date_time="2025-03-15T12:00:00", party_size=4) assert len(slots) > 0 assert all(isinstance(s, AvailabilitySlot) for s in slots) # Slots should include both lunch and dinner times @@ -61,7 +58,7 @@ def test_check_availability_invalid_restaurant(self, provider): provider.check_availability( restaurant_id="invalid_id", date_time="2025-03-15T12:00:00", - party_size=4 + party_size=4, ) def test_place_reservation(self, provider): @@ -73,7 +70,7 @@ def test_place_reservation(self, provider): name="John Doe", phone="+1-555-123-4567", email="john@example.com", - notes="Window seat preferred" + notes="Window seat preferred", ) assert isinstance(reservation, Reservation) assert reservation.restaurant_id == "rest_001" @@ -91,7 +88,7 @@ def test_place_reservation_idempotent(self, provider): party_size=4, name="John Doe", phone="+1-555-123-4567", - email="john@example.com" + email="john@example.com", ) # Place duplicate reservation @@ -101,7 +98,7 @@ def test_place_reservation_idempotent(self, provider): party_size=4, name="John Doe", phone="+1-555-123-4567", - email="john@example.com" + email="john@example.com", ) # Should return same reservation @@ -117,7 +114,7 @@ def test_place_reservation_invalid_restaurant(self, provider): party_size=4, name="John Doe", phone="+1-555-123-4567", - email="john@example.com" + email="john@example.com", ) def test_list_reservations(self, provider): @@ -129,7 +126,7 @@ def test_list_reservations(self, provider): party_size=4, name="John Doe", phone="+1-555-123-4567", - email="john@example.com" + email="john@example.com", ) # List by email @@ -151,14 +148,11 @@ def test_cancel_reservation(self, provider): party_size=4, name="John Doe", phone="+1-555-123-4567", - email="john@example.com" + email="john@example.com", ) # Cancel it - receipt = provider.cancel_reservation( - reservation_id=reservation.id, - reason="Change of plans" - ) + receipt = provider.cancel_reservation(reservation_id=reservation.id, reason="Change of plans") assert receipt.reservation_id == reservation.id assert receipt.reason == "Change of plans" @@ -194,8 +188,8 @@ def test_restaurant_validation(self): address="123 Main St", city="Boston", state="MA", - postal_code="02101" - ) + postal_code="02101", + ), ) assert restaurant.id == "test_001" @@ -214,17 +208,13 @@ def test_restaurant_validation(self): address="123 Main St", city="Boston", state="MA", - postal_code="02101" - ) + postal_code="02101", + ), ) def test_availability_slot_validation(self): """Test AvailabilitySlot schema validation.""" - slot = AvailabilitySlot( - time="2025-03-15T19:00:00", - max_party_size=8, - available=True - ) + slot = AvailabilitySlot(time="2025-03-15T19:00:00", max_party_size=8, available=True) assert slot.available is True # Invalid party size @@ -232,5 +222,5 @@ def test_availability_slot_validation(self): AvailabilitySlot( time="2025-03-15T19:00:00", max_party_size=0, # Invalid: must be >= 1 - available=True + available=True, ) diff --git a/mcp/shopping_tool/__init__.py b/mcp/shopping_tool/__init__.py index 4c44d91f..c70c6348 100644 --- a/mcp/shopping_tool/__init__.py +++ b/mcp/shopping_tool/__init__.py @@ -1,6 +1,5 @@ """Shopping Agent MCP Tool""" -from .shopping_agent import recommend_products, search_products, run_server +from .shopping_agent import recommend_products, run_server, search_products __all__ = ["recommend_products", "search_products", "run_server"] - diff --git a/mcp/shopping_tool/shopping_agent.py b/mcp/shopping_tool/shopping_agent.py index 2c98baa9..45e8e04f 100644 --- a/mcp/shopping_tool/shopping_agent.py +++ b/mcp/shopping_tool/shopping_agent.py @@ -1,11 +1,12 @@ """Shopping Agent MCP Tool - Uses SerpAPI for product search""" import argparse -import os -import sys import json import logging +import os +import sys from typing import Any, Dict, Optional, Union + from fastmcp import FastMCP from serpapi import GoogleSearch @@ -19,6 +20,7 @@ def _env_flag(name: str, default: str = "false") -> bool: value = default return value.strip().lower() in {"1", "true", "yes", "on"} + # Environment variable for API key SERPAPI_API_KEY = os.getenv("SERPAPI_API_KEY") @@ -51,14 +53,14 @@ def _env_flag(name: str, default: str = "false") -> bool: def recommend_products(query: str, max_results: int = 10) -> str: """ Recommend products based on natural language query (e.g., "good curtains under $40") - + This tool searches Google Shopping via SerpAPI and returns structured product data including titles, prices, and descriptions. - + Args: query: Natural language product request max_results: Maximum number of product recommendations to return (default 10, max 20) - + Returns: JSON string containing product search results with names, prices, descriptions, and links. """ @@ -68,13 +70,13 @@ def recommend_products(query: str, max_results: int = 10) -> str: if len(query) > 256: return json.dumps({"error": "Query is too long (max 256 characters)."}) logger.info(f"Searching products for query: '{query}'") - + if not SERPAPI_API_KEY: return json.dumps({"error": "SERPAPI_API_KEY not configured"}) - + # Limit max_results max_results = min(max_results, 20) - + try: # Configure SerpAPI Google Shopping search params = { @@ -84,18 +86,18 @@ def recommend_products(query: str, max_results: int = 10) -> str: "google_domain": "google.com", "gl": "us", "hl": "en", - "num": max_results + "num": max_results, } - + logger.debug(f"Searching with params: {json.dumps(params, default=str)}") search = GoogleSearch(params) results = search.get_dict() - + if "error" in results: return json.dumps({"error": results["error"]}) - + shopping_results = results.get("shopping_results", []) - + # Format products products = [] for item in shopping_results: @@ -107,23 +109,26 @@ def recommend_products(query: str, max_results: int = 10) -> str: "thumbnail": item.get("thumbnail"), "source": item.get("source"), "rating": item.get("rating"), - "reviews": item.get("reviews") + "reviews": item.get("reviews"), } products.append(product) - + # Fallback to regular search if no shopping results found if not products and "organic_results" in results: logger.info("No shopping results found, falling back to organic results") # This might happen if we switch engine to 'google' or if shopping has no results # But with engine='google_shopping', we should get shopping_results pass - - return json.dumps({ - "query": query, - "products": products[:max_results], - "count": len(products[:max_results]) - }, indent=2) - + + return json.dumps( + { + "query": query, + "products": products[:max_results], + "count": len(products[:max_results]), + }, + indent=2, + ) + except Exception as e: logger.error(f"Error in recommend_products: {e}", exc_info=True) return json.dumps({"error": str(e)}) @@ -133,11 +138,11 @@ def recommend_products(query: str, max_results: int = 10) -> str: def search_products(query: str, max_results: int = 10) -> str: """ Search for products using standard Google Search (internal tool) - + Args: query: Product search query max_results: Maximum number of results to return (default 10, max 100) - + Returns: JSON string containing search results """ @@ -147,13 +152,13 @@ def search_products(query: str, max_results: int = 10) -> str: if len(query) > 256: return json.dumps({"error": "Query parameter is too long (max 256 characters)."}) logger.info(f"Searching products for query: '{query}'") - + if not SERPAPI_API_KEY: return json.dumps({"error": "SERPAPI_API_KEY not configured"}) - + # Limit max_results max_results = min(max_results, 100) - + try: # Use standard Google Search for broader context params = { @@ -163,12 +168,12 @@ def search_products(query: str, max_results: int = 10) -> str: "google_domain": "google.com", "gl": "us", "hl": "en", - "num": max_results + "num": max_results, } - + search = GoogleSearch(params) results = search.get_dict() - + if "error" in results: return json.dumps({"error": results["error"]}) @@ -180,7 +185,7 @@ def search_products(query: str, max_results: int = 10) -> str: }, indent=2, ) - + except Exception as e: logger.error(f"Error in search_products: {e}", exc_info=True) return json.dumps({"error": str(e)}) @@ -227,6 +232,7 @@ def run_server( # Attach agent card route if FastMCP exposes the FastAPI app app = getattr(mcp, "app", None) if app: + @app.get("/.well-known/agent.json") def agent_card() -> Dict[str, Any]: return AGENT_CARD @@ -294,10 +300,10 @@ def main() -> int: if SERPAPI_API_KEY is None: logger.error("Please configure the SERPAPI_API_KEY environment variable before running the server") return 1 - + logger.info("Starting Shopping Agent MCP Server with SerpAPI") logger.info("Note: This server provides search results. The calling agent provides reasoning.") - + run_server( transport=args.transport, host=args.host, diff --git a/mcp/slack_tool/__init__.py b/mcp/slack_tool/__init__.py index 9ebcf029..357cd0e2 100644 --- a/mcp/slack_tool/__init__.py +++ b/mcp/slack_tool/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/mcp/slack_tool/slack_tool.py b/mcp/slack_tool/slack_tool.py index 1f9ca9fb..9b1adeda 100644 --- a/mcp/slack_tool/slack_tool.py +++ b/mcp/slack_tool/slack_tool.py @@ -1,18 +1,24 @@ +import logging import os import sys -import logging -from typing import List, Dict, Any +from typing import Any, Dict, List + from fastmcp import FastMCP from slack_sdk import WebClient from slack_sdk.errors import SlackApiError logger = logging.getLogger(__name__) -logging.basicConfig(level=os.getenv("LOG_LEVEL", "DEBUG"), stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "DEBUG"), + stream=sys.stdout, + format="%(levelname)s: %(message)s", +) # setup slack client SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") ADMIN_SLACK_BOT_TOKEN = os.getenv("ADMIN_SLACK_BOT_TOKEN") + def slack_client_from_bot_token(bot_token): try: slack_client = WebClient(token=bot_token) @@ -27,6 +33,7 @@ def slack_client_from_bot_token(bot_token): logger.exception(f"An unexpected error occurred during Slack client initialization: {e}") return None + def get_slack_client(): if ADMIN_SLACK_BOT_TOKEN: return slack_client_from_bot_token(ADMIN_SLACK_BOT_TOKEN) @@ -36,16 +43,17 @@ def get_slack_client(): # Create FastMCP app mcp = FastMCP("Slack") + @mcp.tool() def get_channels() -> List[Dict[str, Any]]: """ Lists all public and private slack channels you have access to. """ - logger.debug(f"Called get_channels tool") + logger.debug("Called get_channels tool") slack_client = get_slack_client() if slack_client is None: - return [{"error": f"Could not start slack client. Check the configured bot token"}] + return [{"error": "Could not start slack client. Check the configured bot token"}] try: # Call the conversations_list method to get public channels @@ -54,7 +62,11 @@ def get_channels() -> List[Dict[str, Any]]: # We'll just return some key information for each channel logger.debug(f"Successful get_channels call: {channels}") return [ - {"id": c["id"], "name": c["name"], "purpose": c.get("purpose", {}).get("value", "")} + { + "id": c["id"], + "name": c["name"], + "purpose": c.get("purpose", {}).get("value", ""), + } for c in channels ] except SlackApiError as e: @@ -65,6 +77,7 @@ def get_channels() -> List[Dict[str, Any]]: logger.exception(f"Unexpected error occurred: {e}") return [{"error": f"An unexpected error occurred: {e}"}] + @mcp.tool() def get_channel_history(channel_id: str, limit: int = 20) -> List: """ @@ -78,22 +91,22 @@ def get_channel_history(channel_id: str, limit: int = 20) -> List: slack_client = get_slack_client() if slack_client is None: - return [{"error": f"Could not start slack client. Check the configured bot token"}] + return [{"error": "Could not start slack client. Check the configured bot token"}] try: # Call the Slack API to list conversations the bot is part of. - response = slack_client.conversations_history( - channel=channel_id, - limit=limit - ) + response = slack_client.conversations_history(channel=channel_id, limit=limit) logger.debug(f"Successful get_channel_history call: {response}") - return response.get("messages",) + return response.get( + "messages", + ) except SlackApiError as e: # Handle API errors and return a descriptive message return [{"error": f"Slack API Error: {e.response['error']}"}] except Exception as e: return [{"error": f"An unexpected error occurred: {e}"}] + # host can be specified with HOST env variable # transport can be specified with MCP_TRANSPORT env variable (defaults to streamable-http) def run_server(): @@ -102,12 +115,15 @@ def run_server(): port = int(os.getenv("PORT", "8000")) mcp.run(transport=transport, host=host, port=port) + if __name__ == "__main__": if SLACK_BOT_TOKEN is None: logger.warning("Please configure the SLACK_BOT_TOKEN environment variable before running the server") else: if ADMIN_SLACK_BOT_TOKEN: - logger.info("Both SLACK_BOT_TOKEN and ADMIN_SLACK_BOT_TOKEN configured; ADMIN_SLACK_BOT_TOKEN takes precedence") + logger.info( + "Both SLACK_BOT_TOKEN and ADMIN_SLACK_BOT_TOKEN configured; ADMIN_SLACK_BOT_TOKEN takes precedence" + ) else: logger.info("Using SLACK_BOT_TOKEN for all Slack API calls") logger.info("Starting Slack MCP Server") diff --git a/mcp/weather_tool/__init__.py b/mcp/weather_tool/__init__.py index 9ebcf029..357cd0e2 100644 --- a/mcp/weather_tool/__init__.py +++ b/mcp/weather_tool/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/mcp/weather_tool/weather_tool.py b/mcp/weather_tool/weather_tool.py index d2722015..d99cd8fc 100644 --- a/mcp/weather_tool/weather_tool.py +++ b/mcp/weather_tool/weather_tool.py @@ -3,13 +3,19 @@ import json import logging import os -import requests import sys + +import requests from fastmcp import FastMCP mcp = FastMCP("Weather") logger = logging.getLogger(__name__) -logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"), stream=sys.stdout, format='%(levelname)s: %(message)s') +logging.basicConfig( + level=os.getenv("LOG_LEVEL", "INFO"), + stream=sys.stdout, + format="%(levelname)s: %(message)s", +) + @mcp.tool(annotations={"readOnlyHint": True, "destructiveHint": False, "idempotentHint": True}) def get_weather(city: str) -> str: @@ -19,7 +25,7 @@ def get_weather(city: str) -> str: params = {"name": city, "count": 1} response = requests.get(base_url, params=params, timeout=10) data = response.json() - if not data or not "results" in data: + if not data or "results" not in data: return f"City {city} not found" latitude = data["results"][0]["latitude"] longitude = data["results"][0]["longitude"] @@ -29,13 +35,14 @@ def get_weather(city: str) -> str: "latitude": latitude, "longitude": longitude, "temperature_unit": "fahrenheit", - "current_weather": True + "current_weather": True, } weather_response = requests.get(weather_url, params=weather_params, timeout=10) weather_data = weather_response.json() return json.dumps(weather_data["current_weather"]) + # host can be specified with HOST env variable # transport can be specified with MCP_TRANSPORT env variable (defaults to streamable-http) def run_server(): @@ -45,5 +52,6 @@ def run_server(): port = int(os.getenv("PORT", "8000")) mcp.run(transport=transport, host=host, port=port) + if __name__ == "__main__": run_server() diff --git a/pyproject.toml b/pyproject.toml index 60feef02..5980f5db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,11 @@ [tool.ruff] line-length = 120 target-version = "py311" +exclude = [".repos"] [tool.ruff.lint] select = ["E", "F", "I", "W"] +ignore = ["E501", "E402", "W291"] [tool.pytest.ini_options] testpaths = ["tests", "mcp/reservation_tool/tests"] diff --git a/tests/a2a/test_contact_extractor.py b/tests/a2a/test_contact_extractor.py index 12d8dabf..2f5df7ff 100644 --- a/tests/a2a/test_contact_extractor.py +++ b/tests/a2a/test_contact_extractor.py @@ -6,7 +6,7 @@ # Mock the marvin dependency before importing sys.modules.setdefault("marvin", MagicMock()) -from agent import TextPart, ExtractionOutcome, _to_text_part, ExtractorAgent +from agent import ExtractionOutcome, ExtractorAgent, TextPart, _to_text_part from pydantic import BaseModel diff --git a/tests/a2a/test_currency_converter.py b/tests/a2a/test_currency_converter.py index eee70a5c..88469491 100644 --- a/tests/a2a/test_currency_converter.py +++ b/tests/a2a/test_currency_converter.py @@ -7,9 +7,14 @@ # Mock heavy dependencies before importing for mod in [ - "langchain_core", "langchain_core.messages", "langchain_core.tools", - "langchain_google_genai", "langchain_openai", - "langgraph", "langgraph.checkpoint", "langgraph.checkpoint.memory", + "langchain_core", + "langchain_core.messages", + "langchain_core.tools", + "langchain_google_genai", + "langchain_openai", + "langgraph", + "langgraph.checkpoint", + "langgraph.checkpoint.memory", "langgraph.prebuilt", ]: sys.modules.setdefault(mod, MagicMock()) diff --git a/tests/a2a/test_weather_service.py b/tests/a2a/test_weather_service.py index fc204fc6..e9c01b95 100644 --- a/tests/a2a/test_weather_service.py +++ b/tests/a2a/test_weather_service.py @@ -16,7 +16,6 @@ sys.modules["weather_service.observability"] = MagicMock() # Now import the configuration module directly — it only needs pydantic_settings -from importlib import import_module import importlib.util # Load configuration.py directly from its file path diff --git a/tests/mcp/test_flight_tool.py b/tests/mcp/test_flight_tool.py index 5e43873f..0334f63e 100644 --- a/tests/mcp/test_flight_tool.py +++ b/tests/mcp/test_flight_tool.py @@ -1,14 +1,14 @@ """Tests for flight_tool MCP server — pure utility functions (isolated from heavy deps).""" import sys -from unittest.mock import MagicMock from datetime import date, timedelta +from unittest.mock import MagicMock # Mock the fastmcp and fast_flights dependencies before importing sys.modules.setdefault("fastmcp", MagicMock()) sys.modules.setdefault("fast_flights", MagicMock()) -from flight_tool import _parse_iso_date, _date_in_past, _coerce_int, _result_to_dict +from flight_tool import _coerce_int, _date_in_past, _parse_iso_date, _result_to_dict class TestParseIsoDate: diff --git a/tests/mcp/test_reservation_schemas.py b/tests/mcp/test_reservation_schemas.py index 747960ea..ca05c0cc 100644 --- a/tests/mcp/test_reservation_schemas.py +++ b/tests/mcp/test_reservation_schemas.py @@ -1,8 +1,7 @@ """Tests for reservation_tool schemas — Pydantic model validation.""" import pytest - -from schemas import Location, Restaurant, AvailabilitySlot, Reservation, CancellationReceipt +from schemas import CancellationReceipt, Location, Restaurant class TestLocation: