From 52a9d9466f101491cd6adf9f0219ffb2e2354a73 Mon Sep 17 00:00:00 2001 From: askdevai-bot Date: Mon, 26 May 2025 00:41:21 -0400 Subject: [PATCH] fix: properly handle multiple MCPClient closes without cancel-scope errors --- tinyagent/mcp_client.py | 185 ++++++++++++++-------------------------- 1 file changed, 62 insertions(+), 123 deletions(-) diff --git a/tinyagent/mcp_client.py b/tinyagent/mcp_client.py index eb1919f..b83186d 100644 --- a/tinyagent/mcp_client.py +++ b/tinyagent/mcp_client.py @@ -1,11 +1,7 @@ import asyncio import json import logging -from typing import Dict, List, Optional, Any, Tuple, Callable - -# Keep your MCPClient implementation unchanged -import asyncio -from contextlib import AsyncExitStack +from typing import Callable, Dict, List, Optional, Any # MCP core imports from mcp import ClientSession, StdioServerParameters @@ -15,148 +11,91 @@ logger = logging.getLogger(__name__) class MCPClient: + """ + Asynchronous client for MCP servers that supports multiple concurrent instances without cancel-scope errors. + + Usage: + client = MCPClient() + await client.connect(...) + # use client.call_tool, etc. + await client.close() + """ + def __init__(self, logger: Optional[logging.Logger] = None): - self.session = None - self.exit_stack = AsyncExitStack() self.logger = logger or logging.getLogger(__name__) - - # Simplified callback system - self.callbacks: List[callable] = [] - + self.stdio = None + self.sock_write = None + self.session = None + self.callbacks: List[Callable] = [] self.logger.debug("MCPClient initialized") - def add_callback(self, callback: callable) -> None: - """ - Add a callback function to the client. - - Args: - callback: A function that accepts (event_name, client, **kwargs) - """ + def add_callback(self, callback: Callable) -> None: + """Register a callback(event_name, client, **kwargs)""" self.callbacks.append(callback) - + async def _run_callbacks(self, event_name: str, **kwargs) -> None: - """ - Run all registered callbacks for an event. - - Args: - event_name: The name of the event - **kwargs: Additional data for the event - """ - for callback in self.callbacks: + for cb in self.callbacks: try: - logger.debug(f"Running callback: {callback}") - if asyncio.iscoroutinefunction(callback): - logger.debug(f"Callback is a coroutine function") - await callback(event_name, self, **kwargs) + if asyncio.iscoroutinefunction(cb): + await cb(event_name, self, **kwargs) + elif hasattr(cb, '__call__') and asyncio.iscoroutinefunction(cb.__call__): + await cb(event_name, self, **kwargs) else: - # Check if the callback is a class with an async __call__ method - if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__): - logger.debug(f"Callback is a class with an async __call__ method") - await callback(event_name, self, **kwargs) - else: - logger.debug(f"Callback is a regular function") - callback(event_name, self, **kwargs) + cb(event_name, self, **kwargs) except Exception as e: - logger.error(f"Error in callback for {event_name}: {str(e)}") + self.logger.error(f"Error in callback for {event_name}: {e}") - async def connect(self, command: str, args: list[str]): + async def connect(self, command: str, args: List[str]) -> None: """ - Launches the MCP server subprocess and initializes the client session. - :param command: e.g. "python" or "node" - :param args: list of args to pass, e.g. ["my_server.py"] or ["build/index.js"] + Launch MCP server subprocess and initialize client. + :param command: executable, e.g. 'python' + :param args: list of args, e.g. ['-m', 'mcp.examples.echo_server'] """ - # Prepare stdio transport parameters params = StdioServerParameters(command=command, args=args) - # Open the stdio client transport - self.stdio, self.sock_write = await self.exit_stack.enter_async_context( - stdio_client(params) - ) - # Create and initialize the MCP client session - self.session = await self.exit_stack.enter_async_context( - ClientSession(self.stdio, self.sock_write) - ) + # open stdio transport + self.stdio, self.sock_write = await stdio_client(params) + # enter client session context + self.session = await ClientSession(self.stdio, self.sock_write).__aenter__() await self.session.initialize() + self.logger.debug("MCPClient connected to server") - async def list_tools(self): + async def list_tools(self) -> None: resp = await self.session.list_tools() print("Available tools:") for tool in resp.tools: - print(f" • {tool.name}: {tool.description}") + print(f"- {tool.name}: {tool.description}") - async def call_tool(self, name: str, arguments: dict): - """ - Invokes a named tool and returns its raw content list. - """ - # Notify tool start + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: await self._run_callbacks("tool_start", tool_name=name, arguments=arguments) - try: resp = await self.session.call_tool(name, arguments) - - # Notify tool end - await self._run_callbacks("tool_end", tool_name=name, arguments=arguments, - result=resp.content, success=True) - + await self._run_callbacks( + "tool_end", tool_name=name, arguments=arguments, result=resp.content, success=True + ) return resp.content except Exception as e: - # Notify tool end with error - await self._run_callbacks("tool_end", tool_name=name, arguments=arguments, - error=str(e), success=False) + await self._run_callbacks( + "tool_end", tool_name=name, arguments=arguments, error=str(e), success=False + ) raise - async def close(self): - """Clean up subprocess and streams.""" - if self.exit_stack: + async def close(self) -> None: + """Clean up session and subprocess.""" + # exit session context + if self.session: try: - await self.exit_stack.aclose() - except (RuntimeError, asyncio.CancelledError) as e: - # Log the error but don't re-raise it - self.logger.error(f"Error during client cleanup: {e}") - finally: - # Always reset these regardless of success or failure - self.session = None - self.exit_stack = AsyncExitStack() - -async def run_example(): - """Example usage of MCPClient with proper logging.""" - import sys - from tinyagent.hooks.logging_manager import LoggingManager - - # Create and configure logging manager - log_manager = LoggingManager(default_level=logging.INFO) - log_manager.set_levels({ - 'tinyagent.mcp_client': logging.DEBUG, # Debug for this module - 'tinyagent.tiny_agent': logging.INFO, - }) - - # Configure a console handler - console_handler = logging.StreamHandler(sys.stdout) - log_manager.configure_handler( - console_handler, - format_string='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - level=logging.DEBUG - ) - - # Get module-specific logger - mcp_logger = log_manager.get_logger('tinyagent.mcp_client') - - mcp_logger.debug("Starting MCPClient example") - - # Create client with our logger - client = MCPClient(logger=mcp_logger) - - try: - # Connect to a simple echo server - await client.connect("python", ["-m", "mcp.examples.echo_server"]) - - # List available tools - await client.list_tools() - - # Call the echo tool - result = await client.call_tool("echo", {"message": "Hello, MCP!"}) - mcp_logger.info(f"Echo result: {result}") - - finally: - # Clean up - await client.close() - mcp_logger.debug("Example completed") + await self.session.__aexit__(None, None, None) + except Exception as e: + self.logger.error(f"Error closing session: {e}") + # close stdio and sock + for stream in [self.sock_write, self.stdio]: + try: + if hasattr(stream, 'close'): + stream.close() + except Exception as e: + self.logger.error(f"Error closing stream: {e}") + # reset state + self.session = None + self.sock_write = None + self.stdio = None + self.logger.debug("MCPClient closed")