diff --git a/adk_a2a_mcp_integration/.env.example b/adk_a2a_mcp_integration/.env.example new file mode 100644 index 0000000..0ef764a --- /dev/null +++ b/adk_a2a_mcp_integration/.env.example @@ -0,0 +1,5 @@ +# To use api key +GOOGLE_GENAI_USE_VERTEXAI=False + +# Google api key, get this from ai studio +GOOGLE_API_KEY= \ No newline at end of file diff --git a/adk_a2a_mcp_integration/Readme.md b/adk_a2a_mcp_integration/Readme.md new file mode 100644 index 0000000..0ae140e --- /dev/null +++ b/adk_a2a_mcp_integration/Readme.md @@ -0,0 +1,19 @@ +# Boilerplate for Agent interaction with ADK, A2A and MCP. + +[To UPDATE] +A simple implementation of agent which has access to some custom tools and an MCP server for exploring Arxiv research papers. It also has the ability to interact with remote agents using A2A Protocol. The agents are built using ADK and can be run using `adk web` or custom streamlit UI. + +```bash +# To start remote agent +python -m remote_agent +``` + +```bash +# To start MCP server +python mcp_server/server.py +``` + +```bash +# To interact with root agent in UI +streamlit run app.py +``` \ No newline at end of file diff --git a/adk_a2a_mcp_integration/app.py b/adk_a2a_mcp_integration/app.py new file mode 100644 index 0000000..3f70c6b --- /dev/null +++ b/adk_a2a_mcp_integration/app.py @@ -0,0 +1,95 @@ +"""Module that interacts with backend agent using streamlit""" +# pylint: disable=line-too-long,invalid-name +import uuid +import asyncio +import streamlit as st +from root_agent.response_manager import ResponseManager + +# --- Page Configuration --- +st.set_page_config(page_title="Agent Response Manager", layout="wide") + +# --- Session State Initialization --- +if "response_manager" not in st.session_state: + st.session_state.response_manager = ResponseManager() +if "session_id" not in st.session_state: + st.session_state.session_id = str(uuid.uuid4()) + st.toast(f"New session created: {st.session_state.session_id}") +if "messages" not in st.session_state: + st.session_state.messages = [] + +# --- Sidebar --- +with st.sidebar: + st.title("Settings") + st.markdown(f"**Session ID:**\n`{st.session_state.session_id}`") + + if st.button("New Session"): + st.session_state.session_id = str(uuid.uuid4()) + st.session_state.messages = [] + st.toast(f"New session started: {st.session_state.session_id}") + st.rerun() + + st.divider() + diagnostics_enabled = st.checkbox("Enable Diagnostics", value=True) + +# --- Main Chat Interface --- +st.title("Agent Chat UI") + +# Display chat history from session state +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + if diagnostics_enabled and "diagnostics" in message and message["diagnostics"]: + with st.expander("View Diagnostics for this response."): + st.json(message["diagnostics"]) + +# --- Coroutine to handle agent invocation and UI updates --- +async def run_agent_and_display(prompt: str): + """Add user message to session state and display it""" + st.session_state.messages.append({"role": "user", "content": prompt}) + with st.chat_message("user"): + st.markdown(prompt) + + # --- Assistant's turn --- + with st.chat_message("assistant"): + final_response_text = "" + diagnostic_events = [] + status_messages = { + "tool_call": "Calling tool... 🛠️", + "tool_response": "Processing tool result... ⚙️", + "finished": "Response generation complete. ✅" + } + + status_placeholder = st.empty() + try: + # Collect all events and the final response from the agent + status_placeholder.text("Thinking... 🤔") + async for event in st.session_state.response_manager.invoke_agent( + session_id=st.session_state.session_id, query=prompt + ): + status_text = status_messages.get(event.get("status"), f"Processing: {event.get('status')}") + status_placeholder.text(status_text) + + diagnostic_events.append(event.get("event")) + if event.get("is_final_response"): + final_response_text = event.get("result", "Sorry, I couldn't generate a response.") + + except Exception as e: + st.error(f"An error occurred: {e}") + final_response_text = "Sorry, I ran into an error." + + status_placeholder.empty() + st.markdown(final_response_text) + + if diagnostics_enabled and diagnostic_events: + with st.expander("View Diagnostics for this response."): + st.json(diagnostic_events) + + st.session_state.messages.append({ + "role": "assistant", + "content": final_response_text, + "diagnostics": diagnostic_events + }) + +# --- Handle User Input --- +if user_input := st.chat_input("What are some trending topics in AI?"): + asyncio.run(run_agent_and_display(user_input)) diff --git a/adk_a2a_mcp_integration/mcp_server/requirements.txt b/adk_a2a_mcp_integration/mcp_server/requirements.txt new file mode 100644 index 0000000..0ca0464 --- /dev/null +++ b/adk_a2a_mcp_integration/mcp_server/requirements.txt @@ -0,0 +1,3 @@ +fastmcp==2.10.6 +arxiv==2.2.0 +pymupdf4llm==0.0.27 \ No newline at end of file diff --git a/adk_a2a_mcp_integration/mcp_server/server.py b/adk_a2a_mcp_integration/mcp_server/server.py new file mode 100644 index 0000000..39367fc --- /dev/null +++ b/adk_a2a_mcp_integration/mcp_server/server.py @@ -0,0 +1,158 @@ +"""Module that implements simple mcp server to query Arxiv research papers collection""" +import os +import logging +import tempfile +import requests +import arxiv +import pymupdf4llm +from fastmcp import FastMCP + +logger = logging.getLogger(__name__) + +mcp = FastMCP("ArxivExplorer") + +@mcp.tool +def search_arxiv(query: str, max_results: int = 5) -> dict: + """ + Searches arXiv for a given query and returns the top papers. + Args: + query: The search keyword or query. + max_results: The maximum number of results to return. + Returns: + A list of dictionaries, where each dictionary represents a paper + and contains its ID, title, summary, authors, and PDF URL. + """ + try: + search = arxiv.Search( + query=query, + max_results=max_results, + sort_by=arxiv.SortCriterion.Relevance + ) + + papers = [] + for result in search.results(): + logger.info(f"{result.title}") + paper_info = { + 'id': result.get_short_id(), + 'title': result.title, + 'summary': result.summary, + 'authors': [author.name for author in result.authors], + 'pdf_url': result.pdf_url + } + papers.append(paper_info) + + return { + "status": "success", + "result": papers + } + except Exception as e: + return { + "status": "error", + "error_message": str(e) + } + +@mcp.tool() +def get_paper_md(paper_id: str) -> dict: + """ + Retrieves the text content of an arXiv paper in Markdown format. + Args: + paper_id: The ID of the paper (e.g., '1706.03762v7'). + Returns: + The text content of the paper as a Markdown string. + Returns an error message if any step fails. + """ + try: + search = arxiv.Search(id_list=[paper_id]) + paper = next(search.results()) + pdf_url = paper.pdf_url + logger.info(f"Found paper: '{paper.title}'") + logger.info(f"Downloading from: {pdf_url}") + + except StopIteration: + return { + "status": "error", + "error_message": f"Paper with ID '{paper_id}' not found on arXiv." + } + except Exception as e: + return { + "status": "error", + "error_message": f"Error searching for the paper: {e}" + } + + try: + # Download the PDF content + response = requests.get(pdf_url) + response.raise_for_status() + pdf_bytes = response.content + logger.info("PDF downloaded successfully.") + + except requests.exceptions.RequestException as e: + return { + "status": "error", + "error_message": f"Error downloading the PDF file, request failure: {e}" + } + except Exception as e: + return { + "status": "error", + "error_message": f"Error downloading the PDF file: {e}" + } + + temp_pdf_path = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file: + temp_file.write(pdf_bytes) + temp_pdf_path = temp_file.name # Get the path of the temporary file + + logger.info(f"PDF content written to temporary file: {temp_pdf_path}") + logger.info("Converting PDF to Markdown...") + # Pass the file path to the conversion function + md_text = pymupdf4llm.to_markdown(temp_pdf_path) + + logger.info("Conversion complete.") + if temp_pdf_path and os.path.exists(temp_pdf_path): + os.remove(temp_pdf_path) + logger.info(f"Temporary file {temp_pdf_path} deleted.") + return {"status": "success", "result": md_text} + + except Exception as e: + return { + "status": "error", + "error_message": f"Error converting PDF to Markdown: {e}" + } + +# To pass the paper directly as media to llm, to be used later +# @mcp.tool() +# def get_paper_raw(paper_id: str) -> dict: +# """ +# Retrieves the raw PDF file of an arXiv paper. +# Args: +# paper_id: The ID of the paper (e.g., '1706.03762v7'). +# Returns: +# The raw bytes of the PDF file, or None if the paper is not found. +# """ +# try: +# # Search for the paper by its ID +# search = arxiv.Search(id_list=[paper_id]) +# paper = next(search.results()) + +# # Download the PDF content +# response = requests.get(paper.pdf_url) +# response.raise_for_status() +# return { +# "status": "success", +# "result":response.content +# } +# except StopIteration: +# return { +# "status": "error", +# "error_message": f"Paper with ID {paper_id} not found on arXiv." +# } +# except requests.exceptions.RequestException as e: +# logger.info(f"Error downloading PDF: {e}") +# return {"status": "error", "error_message": f"Error downloading PDF: {e}"} +# except Exception as e: +# logger.info(f"Error: {e}") +# return {"status": "error", "error_message": f"Error: {e}"} + +if __name__ == "__main__": + mcp.run(transport="http", host="127.0.0.1", port=8000, path="/mcp") diff --git a/adk_a2a_mcp_integration/remote_agent/__init__.py b/adk_a2a_mcp_integration/remote_agent/__init__.py new file mode 100644 index 0000000..96d00fd --- /dev/null +++ b/adk_a2a_mcp_integration/remote_agent/__init__.py @@ -0,0 +1,2 @@ +"""Module to start the agent, primarily used by adk web""" +from . import agent diff --git a/adk_a2a_mcp_integration/remote_agent/__main__.py b/adk_a2a_mcp_integration/remote_agent/__main__.py new file mode 100644 index 0000000..a48c5e2 --- /dev/null +++ b/adk_a2a_mcp_integration/remote_agent/__main__.py @@ -0,0 +1,56 @@ +"""Module that exposes the remote agent as a server, sharing it's capabilities +and methods to invoke it""" +import logging +import uvicorn +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCapabilities, AgentCard, AgentSkill +from remote_agent.task_manager import BasicSearchAgentExecutor + +logger = logging.getLogger(__name__) + +AGENT_URL = "http://localhost:8090/" + +def start_remote_agent(): + """Start the remote agent and expose it's capabilities""" + agent_skill = AgentSkill( + id="search_agent", + name="Search Agent", + description="""Agent that can get the latest search results from the + internet using google search and gives accurate results""", + input_modes=["text"], + output_modes=["text"], + tags=["search agent", "google search tool", "web search"], + examples=[ + "What are the latest news in AI?", + "Explain the key difference between Langchain and Langgraph.", + "Who won the last IPL match?"] + ) + + public_agent_card = AgentCard( + name="Search agent", + description="Agent that can search the internet to answer queries.", + url=AGENT_URL, + version="0.0.1", + skills=[agent_skill], + defaultInputModes=['text'], + defaultOutputModes=['text'], + capabilities=AgentCapabilities(streaming=True), + supportsAuthenticatedExtendedCard=False, + ) + + request_handler = DefaultRequestHandler( + agent_executor=BasicSearchAgentExecutor(), + task_store=InMemoryTaskStore(), + ) + server = A2AStarletteApplication( + agent_card=public_agent_card, + http_handler=request_handler + ) + app = server.build() + logger.info("Uvicorn server starting...") + uvicorn.run(app, host="127.0.0.1", port=8090) + +if __name__ == "__main__": + start_remote_agent() diff --git a/adk_a2a_mcp_integration/remote_agent/agent.py b/adk_a2a_mcp_integration/remote_agent/agent.py new file mode 100644 index 0000000..4a3936c --- /dev/null +++ b/adk_a2a_mcp_integration/remote_agent/agent.py @@ -0,0 +1,84 @@ +"""Module that implements the core logic for the search agent""" +import logging + +from google.adk import Runner +from google.adk.agents import Agent +from google.adk.artifacts import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.sessions import InMemorySessionService +from google.adk.tools import google_search +from google.genai import types + +from dotenv import load_dotenv +# To load the google api keys +load_dotenv() + +logger = logging.getLogger(__name__) + +root_agent = Agent( + name="search_agent", + model="gemini-2.0-flash", + description="Agent capable of searching internet to find relevant answers to user questions.", + instruction="""You are an friendly and supportive agent. Your job is to try to answer + the user question using the search tool. Always provide accurate and relevant information.""", + tools=[google_search], +) + +class BasicSearchAgent: + """Class that exposes the basic search agent""" + def __init__(self): + self.agent = root_agent + self.runner = Runner( + app_name=self.agent.name, + agent=self.agent, + artifact_service=InMemoryArtifactService(), + memory_service=InMemoryMemoryService(), + session_service=InMemorySessionService(), + ) + + async def invoke(self, session_id: str, query: str, user_id: str = None): + """Invoke the agent""" + try: + if not user_id: + user_id = "Default User" + session_instance = await self.runner.session_service.get_session( + session_id=session_id, + user_id=user_id, + app_name=self.agent.name + ) + + if not session_instance: + logger.info(f"Creating new session with id: {session_id}") + session_instance = await self.runner.session_service.create_session( + session_id=session_id, + user_id=user_id, + app_name=self.agent.name + ) + + user_content = types.Content( + role="user", parts=[types.Part.from_text(text=query)] + ) + + final_response_text = "" + async for event in self.runner.run_async( + user_id=user_id, + session_id=session_instance.id, + new_message=user_content + ): + # We can break when there's final response, + # but for telemetry usage, the loop must complete + # logger.debug(f"Event: {event}") + if event.is_final_response(): + if event.content and event.content.parts and event.content.parts[-1].text: + final_response_text = event.content.parts[-1].text + logger.info("Loop finished, yielding final response.") + yield { + "status": "success", + "result": final_response_text + } + except Exception as e: + logger.info(f"Error: {e}") + yield { + "status": "error", + "error_message": str(e) + } diff --git a/adk_a2a_mcp_integration/remote_agent/client.py b/adk_a2a_mcp_integration/remote_agent/client.py new file mode 100644 index 0000000..ee6a200 --- /dev/null +++ b/adk_a2a_mcp_integration/remote_agent/client.py @@ -0,0 +1,60 @@ +"""Module to test the remote agent exposed via A2A. Mimic's client side implementation""" +import asyncio +import logging +from typing import Any +from uuid import uuid4 +import httpx +from a2a.client import A2ACardResolver, A2AClient +from a2a.types import MessageSendParams, SendMessageRequest + +logger = logging.getLogger(__name__) + +async def client(): + """Test the agent with a simple client""" + async with httpx.AsyncClient(timeout=120) as httpx_client: + resolver = A2ACardResolver( + httpx_client=httpx_client, + base_url="http://localhost:8090/", + ) + logger.info("Attempting to fetch agent card...") + agent_card = await resolver.get_agent_card() + logger.info('Agent card fetched. Agent card:') + logger.info(agent_card.model_dump_json(indent=2, exclude_none=True)) + + logger.info("Initializing A2A Client") + client_instance = A2AClient( + httpx_client=httpx_client, agent_card=agent_card + ) + logger.info('A2A Client initialized.') + + send_message_payload: dict[str, Any] = { + 'message': { + 'role': 'user', + 'parts': [ + { + 'kind': 'text', + 'text': 'What is model context protocol? Give a brief description.' + } + ], + 'messageId': uuid4().hex, + }, + } + logger.info("Sending test message") + request = SendMessageRequest( + id=str(uuid4()), params=MessageSendParams(**send_message_payload) + ) + + response = await client_instance.send_message(request) + logger.info(response.model_dump(mode='json', exclude_none=True)) + response_dict = response.model_dump(mode='json', exclude_none=True) + agent_response_text = "No text content found in response or an error occurred." + try: + agent_response_text = response_dict['result']['parts'][0]['text'] + except (KeyError, IndexError) as e: + logger.info(f"Error parsing agent response structure: {e}") + + logger.info("\n--- Agent's Final Response ---") + logger.info(agent_response_text) + logger.info("----------------------------") + +asyncio.run(client()) diff --git a/adk_a2a_mcp_integration/remote_agent/task_manager.py b/adk_a2a_mcp_integration/remote_agent/task_manager.py new file mode 100644 index 0000000..627ed9c --- /dev/null +++ b/adk_a2a_mcp_integration/remote_agent/task_manager.py @@ -0,0 +1,49 @@ +"""Module that handles the invocation of agent""" +import logging +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.utils import new_agent_text_message, new_task + +from remote_agent.agent import BasicSearchAgent + +logger = logging.getLogger(__name__) + +class BasicSearchAgentExecutor(AgentExecutor): + """Agent executor that invokes the agent""" + + def __init__(self): + super().__init__() + self._agent = BasicSearchAgent() + + async def execute(self, request_context: RequestContext, event_queue: EventQueue) -> None: + """Invokes the agent with the request context""" + query = request_context.get_user_input() + task = request_context.current_task + if not task: + task = new_task(request_context.message) + await event_queue.enqueue_event(task) + logger.info(f"Creating new task with id: {task.id}.") + + session_id = task.contextId + logger.info(f"Using session id: {session_id}.") + async for event in self._agent.invoke( + session_id=session_id, + query=query + # user_id=request_context.user_id, + ): + if event.get("status") == "success": + await event_queue.enqueue_event( + new_agent_text_message(text=event.get("result")) + ) + elif event.get("status") == "error": + await event_queue.enqueue_event( + new_agent_text_message(text=f"Error: {event.get('error_message')}") + ) + else: + await event_queue.enqueue_event( + new_agent_text_message(text=event.get("result")) + ) + + async def cancel(self, request_context: RequestContext, event_queue: EventQueue) -> None: + """To cancel an ongoing agent execution""" + raise ValueError('cancel not supported at this moment!') diff --git a/adk_a2a_mcp_integration/requirements.txt b/adk_a2a_mcp_integration/requirements.txt new file mode 100644 index 0000000..b411c0e --- /dev/null +++ b/adk_a2a_mcp_integration/requirements.txt @@ -0,0 +1,19 @@ +# mcp server requirements +fastmcp==2.10.6 +arxiv==2.2.0 +pymupdf4llm==0.0.27 + +# general a2a requirements +a2a-sdk==0.2.16 + +# adk requirements +google-adk==1.8.0 +google-genai==1.27.0 +fastapi==0.116.1 + +python-dotenv==1.1.1 +requests==2.32.4 + +uvicorn==0.35.0 + +streamlit==1.47.1 \ No newline at end of file diff --git a/adk_a2a_mcp_integration/root_agent/.env.example b/adk_a2a_mcp_integration/root_agent/.env.example new file mode 100644 index 0000000..0ef764a --- /dev/null +++ b/adk_a2a_mcp_integration/root_agent/.env.example @@ -0,0 +1,5 @@ +# To use api key +GOOGLE_GENAI_USE_VERTEXAI=False + +# Google api key, get this from ai studio +GOOGLE_API_KEY= \ No newline at end of file diff --git a/adk_a2a_mcp_integration/root_agent/__init__.py b/adk_a2a_mcp_integration/root_agent/__init__.py new file mode 100644 index 0000000..96d00fd --- /dev/null +++ b/adk_a2a_mcp_integration/root_agent/__init__.py @@ -0,0 +1,2 @@ +"""Module to start the agent, primarily used by adk web""" +from . import agent diff --git a/adk_a2a_mcp_integration/root_agent/agent.py b/adk_a2a_mcp_integration/root_agent/agent.py new file mode 100644 index 0000000..f99ca7a --- /dev/null +++ b/adk_a2a_mcp_integration/root_agent/agent.py @@ -0,0 +1,35 @@ +"""Core module for agent orchestration""" +from google.adk.agents import Agent +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, StreamableHTTPConnectionParams + +from root_agent.tools import get_current_time, calculate_expression +from root_agent.remote_agent_helpers import list_remote_agents, call_remote_agent + +from dotenv import load_dotenv +# To load the google api keys +load_dotenv() + +simple_mcp_tool = MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url="http://localhost:8000/mcp", + timeout=10, + sse_read_timeout=60 * 5, + terminate_on_close=True, + ), + tool_filter=["search_arxiv","get_paper_md"] +) + +root_agent = Agent( + name="root_agent", + model="gemini-2.0-flash", + description="Agent to answer questions using tools provided.", + instruction="""You are a helpful agent who can answer user questions about current time + and can do calculations. For any queries that require latest/external information, + identify if any remote agents can help with that. Once you found the relevant agents, + use the appropriate tools to get the answer the user query.""", + tools=[get_current_time, + calculate_expression, + simple_mcp_tool, + list_remote_agents, + call_remote_agent], +) diff --git a/adk_a2a_mcp_integration/root_agent/remote_agent_helpers.py b/adk_a2a_mcp_integration/root_agent/remote_agent_helpers.py new file mode 100644 index 0000000..cd271fd --- /dev/null +++ b/adk_a2a_mcp_integration/root_agent/remote_agent_helpers.py @@ -0,0 +1,93 @@ +"""Helper functions to interact with other agents exposed via A2A""" +import logging +from typing import Any +from uuid import uuid4 +import httpx +from a2a.client import A2ACardResolver, A2AClient +from a2a.types import MessageSendParams, SendMessageRequest + +logger = logging.getLogger(__name__) + +remote_agent_cards_cache = [] + +REMOTE_AGENT_URL = "http://localhost:8090/" + +async def list_remote_agents() -> list[dict]: + """Fetches the capabilities of all available remote agents""" + + if remote_agent_cards_cache: + logger.info("Remote agents card cache exists, using cache") + return remote_agent_cards_cache + logger.info("Agent card cache empty, fetching...") + try: + async with httpx.AsyncClient(timeout=120) as httpx_client: + # get only one agent url for now + resolver = A2ACardResolver( + httpx_client=httpx_client, + base_url=REMOTE_AGENT_URL, + ) + logger.info("Attempting to fetch agent card...") + agent_card = await resolver.get_agent_card() + logger.info('Agent card fetched. Agent card:') + logger.info(agent_card.model_dump_json(indent=2, exclude_none=True)) + + remote_agent_cards_cache.append({ + "agent_name": agent_card.name, + "agent_card": agent_card + }) + logger.info("Adding data to cache...") + return remote_agent_cards_cache + except Exception as e: + logger.error("Failed to fetch agent card.") + raise RuntimeError("Failed to fetch the agent card. Unable to proceed") from e + +async def call_remote_agent(query: str, agent_name: str) -> str: + """Call the remote agent with appropriate query""" + + agent_cards = await list_remote_agents() + + agent_card_to_use = None + + for card in agent_cards: + if card.get("agent_name") == agent_name: + agent_card_to_use = card.get("agent_card") + break + if agent_card_to_use is None: + raise ValueError(f"Agent with name '{agent_name}' not found in available agents.") + logger.info("Initializing A2A Client...") + async with httpx.AsyncClient(timeout=120) as httpx_client: + + client = A2AClient( + httpx_client=httpx_client, + agent_card=agent_card_to_use + ) + logger.info("A2A Client Initialized.") + + send_message_payload: dict[str, Any] = { + 'message': { + 'role': 'user', + 'parts': [ + {'kind': 'text', 'text': query} + ], + 'messageId': uuid4().hex, + }, + } + logger.info("Sending query to remote agent...") + request = SendMessageRequest( + id=str(uuid4()), params=MessageSendParams(**send_message_payload) + ) + try: + response = await client.send_message(request) + response_dict = response.model_dump(mode='json', exclude_none=True) + logger.info(f"Response received from remote agent: {response_dict}") + except Exception as e: + logger.error("Failed to send message to remote agent.") + raise RuntimeError("Remote agent call failed.") from e + + agent_response_text = "No text content found in response or an error occurred." + try: + agent_response_text = response_dict['result']['parts'][0]['text'] + except (KeyError, IndexError) as e: + logger.error(f"Error parsing agent response structure: {e}") + + return agent_response_text diff --git a/adk_a2a_mcp_integration/root_agent/response_manager.py b/adk_a2a_mcp_integration/root_agent/response_manager.py new file mode 100644 index 0000000..6a88ff7 --- /dev/null +++ b/adk_a2a_mcp_integration/root_agent/response_manager.py @@ -0,0 +1,115 @@ +"""Module that handles interaction with the agent, maintains session and query passing.""" +import asyncio +import uuid +import logging + +from google.genai import types +from google.adk.runners import Runner +from google.adk.artifacts import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.sessions import InMemorySessionService + +from root_agent.agent import root_agent + +logger = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +class ResponseManager: + """Class that handles the respone management with root agent""" + def __init__(self): + self.agent = root_agent + self.user_id = "u_123" + self.runner = Runner( + app_name=self.agent.name, + agent=self.agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + + async def invoke_agent(self, session_id: str, query: str): + """Invokes the root agent while maintaining sessionid + for continuance""" + try: + logger.info(f"Fetching session data with id: {session_id}") + session = await self.runner.session_service.get_session( + app_name=self.agent.name, + user_id=self.user_id, + session_id=session_id + ) + + if session is None: + logger.info(f"Session doesn't exist, creating new session with id: {session_id}") + session = await self.runner.session_service.create_session( + app_name=self.agent.name, + user_id=self.user_id, + session_id=session_id, + state={} + ) + + logger.info(f"{session.id} Running query: {query}") + + content = types.Content(role="user",parts=[types.Part.from_text(text=query)]) + + final_response_text = "" + async for event in self.runner.run_async( + user_id=self.user_id, + session_id=session.id, + new_message=content + ): + if event.get_function_calls(): + yield { + "is_final_response": False, + "status": "tool_call", + "event": event.model_dump(mode='json', exclude_none=True) + } + if event.get_function_responses(): + yield { + "is_final_response": False, + "status": "tool_response", + "event": event.model_dump(mode='json', exclude_none=True) + } + if event.is_final_response(): + if event.content and event.content.parts and event.content.parts[-1].text: + final_response_text = event.content.parts[-1].text + + logger.info(f"[{session_id}] Final text: {final_response_text}]") + yield { + "is_final_response": True, + "status": "finished", + "result": final_response_text, + "event": event.model_dump(mode='json', exclude_none=True) + } + except Exception as e: + logger.error(f"Error generating response. {str(e)}") + yield { + "is_final_response": True, + "status": "fail", + "result": "No final response received", + "error_message": str(e) + } + +async def test_agent(): + """Utils function to test the agent using response manager""" + response_manager = ResponseManager() + + session_id = str(uuid.uuid4()) + + first_query = "What are some trending topics in AI?" + # first_response = await response_manager.invoke_agent(session_id=session_id, + # query=first_query) + first_response = response_manager.invoke_agent(session_id=session_id, query=first_query) + async for response in first_response: + logger.info(f"Events: {response}") + # logger.info(f"First response: {first_response}") + + # second_query = "What question did I ask you?" + # second_response = await response_manager.invoke_agent(session_id=session_id, + # query=second_query) + # logger.info(f"Second response: {second_response}") + +# asyncio.run(test_agent()) diff --git a/adk_a2a_mcp_integration/root_agent/tools.py b/adk_a2a_mcp_integration/root_agent/tools.py new file mode 100644 index 0000000..32b7fec --- /dev/null +++ b/adk_a2a_mcp_integration/root_agent/tools.py @@ -0,0 +1,79 @@ +"""Module that defines all the tools used by the root agent""" +# pylint: disable=eval-used +import math +import datetime +from zoneinfo import ZoneInfo + +def get_current_time(country: str) -> dict: + """Returns the current time in a specified country. + Args: + country (str): The name of the country for which to retrieve the current time. + Returns: + dict: status and result or error msg. + """ + if country.lower() == "india": + tz_identifier = "Asia/Kolkata" + else: + return { + "status": "error", + "error_message": ( + f"Sorry, I don't have timezone information for {country}." + ), + } + + tz = ZoneInfo(tz_identifier) + now = datetime.datetime.now(tz) + report = ( + f'The current time in {country} is {now.strftime("%Y-%m-%d %H:%M:%S %Z%z")}' + ) + return {"status": "success", "result": report} + +ALLOWED_FUNCTIONS = { + "math": math, + "exp": math.exp, + "log": math.log, + "log10": math.log10, + "sqrt": math.sqrt, + "pi": math.pi, + "e": math.e, + "ceil": math.ceil, + "floor": math.floor, + "round": round, + "factorial": math.factorial, + "isinf": math.isinf, + "isnan": math.isnan, + "isqrt": math.isqrt, +} + +def calculate_expression(expression: str) -> dict: + """Evaluates a mathematical expression and returns the result. + Supports basic operators (+, -, *, /, **, %), mathematical functions + and constants (pi, e). Uses a restricted evaluation context for safe execution. + + Args: + expression: The mathematical expression to evaluate as a string. + Examples: "2 + 2", "sqrt(16) * 2", "log(100, 10)" + Returns: + On success: {"result": } + On error: {"error": } + + Notes: + - Use 'x' as the variable (e.g., x**2, not x²) + - Multiplication must be explicitly indicated with * (e.g., 2*x, not 2x) + - Powers are represented with ** (e.g., x**2, not x^2) + """ + try: + result = eval( + expression, + {"__builtins__": {}}, + ALLOWED_FUNCTIONS, + ) + return { + "status": "success", + "result": result + } + except Exception as e: + return { + "status": "error", + "error_message": str(e) + }