diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py
index d7c758be7..3ebcb5794 100644
--- a/src/agents/mcp/server.py
+++ b/src/agents/mcp/server.py
@@ -638,6 +638,7 @@ async def get_prompt(
     async def cleanup(self):
         """Cleanup the server."""
         async with self._cleanup_lock:
+            cleanup_cancelled = False
             # Only raise HTTP errors if we're cleaning up after a failed connection.
             # During normal teardown (via __aexit__), log but don't raise to avoid
             # masking the original exception.
@@ -646,9 +647,26 @@ async def cleanup(self):
             try:
                 await self.exit_stack.aclose()
             except asyncio.CancelledError as e:
+                cleanup_cancelled = True
                 logger.debug(f"Cleanup cancelled for MCP server '{self.name}': {e}")
                 raise
             except BaseExceptionGroup as eg:
+
+                def contains_cancelled_error(exc: BaseException) -> bool:
+                    if isinstance(exc, asyncio.CancelledError):
+                        return True
+                    if isinstance(exc, BaseExceptionGroup):
+                        return any(contains_cancelled_error(inner) for inner in exc.exceptions)
+                    return False
+
+                if contains_cancelled_error(eg):
+                    cleanup_cancelled = True
+                    logger.debug(
+                        "Cleanup cancelled for MCP server "
+                        f"'{self.name}' with grouped exception: {eg}"
+                    )
+                    raise
+
                 # Extract HTTP errors from ExceptionGroup raised during cleanup
                 # This happens when background tasks fail (e.g., HTTP errors)
                 http_error = None
@@ -709,7 +727,12 @@ async def cleanup(self):
                 else:
                     logger.error(f"Error cleaning up server: {e}")
             finally:
-                self.session = None
+                if not cleanup_cancelled:
+                    # Reset stack state only after a completed cleanup. If cleanup is cancelled,
+                    # keep the existing stack so a follow-up cleanup can finish unwinding it.
+                    self.exit_stack = AsyncExitStack()
+                    self.session = None
+                    self.server_initialize_result = None
 
 
 class MCPServerStdioParams(TypedDict):
diff --git a/src/agents/result.py b/src/agents/result.py
index 5e27634f7..365d00079 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -8,6 +8,9 @@
 from dataclasses import InitVar, dataclass, field
 from typing import Any, Literal, TypeVar, cast
 
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import core_schema
+
 from .agent import Agent
 from .agent_output import AgentOutputSchemaBase
 from .exceptions import (
@@ -124,6 +127,16 @@ class RunResultBase(abc.ABC):
     _trace_state: TraceState | None = field(default=None, init=False, repr=False)
     """Serialized trace metadata captured during the run."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
+    ) -> core_schema.CoreSchema:
+        # RunResult objects are runtime values; schema generation should treat them as instances
+        # instead of recursively traversing internal dataclass annotations.
+        return core_schema.is_instance_schema(cls)
+
     @property
     @abc.abstractmethod
     def last_agent(self) -> Agent[Any]:
diff --git a/tests/mcp/test_connect_disconnect.py b/tests/mcp/test_connect_disconnect.py
index b00130397..0c5dcbf33 100644
--- a/tests/mcp/test_connect_disconnect.py
+++ b/tests/mcp/test_connect_disconnect.py
@@ -1,3 +1,5 @@
+import asyncio
+import sys
 from unittest.mock import AsyncMock, patch
 
 import pytest
@@ -7,6 +9,46 @@
 from .helpers import DummyStreamsContextManager, tee
+
+if sys.version_info < (3, 11):
+    from exceptiongroup import BaseExceptionGroup  # pyright: ignore[reportMissingImports]
+else:
+    from builtins import BaseExceptionGroup
+
+
+class CountingStreamsContextManager:
+    def __init__(self, counter: dict[str, int]):
+        self.counter = counter
+
+    async def __aenter__(self):
+        self.counter["enter"] += 1
+        return (object(), object())
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        self.counter["exit"] += 1
+
+
+class CancelThenCloseExitStack:
+    def __init__(self):
+        self.close_calls = 0
+
+    async def aclose(self):
+        self.close_calls += 1
+        if self.close_calls == 1:
+            raise asyncio.CancelledError("first cleanup interrupted")
+
+
+class CancelGroupThenCloseExitStack:
+    def __init__(self):
+        self.close_calls = 0
+
+    async def aclose(self):
+        self.close_calls += 1
+        if self.close_calls == 1:
+            raise BaseExceptionGroup(
+                "grouped cancellation during cleanup",
+                [asyncio.CancelledError("grouped cleanup interruption")],
+            )
 
 
 @pytest.mark.asyncio
 @patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@@ -67,3 +109,91 @@ async def test_manual_connect_disconnect_works(
     await server.cleanup()
 
     assert server.session is None, "Server should be disconnected"
+
+
+@pytest.mark.asyncio
+@patch("agents.mcp.server.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
+@patch("agents.mcp.server.stdio_client")
+async def test_cleanup_resets_exit_stack_and_reconnects(
+    mock_stdio_client: AsyncMock, mock_initialize: AsyncMock
+):
+    counter = {"enter": 0, "exit": 0}
+    mock_stdio_client.side_effect = lambda params: CountingStreamsContextManager(counter)
+
+    server = MCPServerStdio(
+        params={
+            "command": tee,
+        },
+        cache_tools_list=True,
+    )
+
+    await server.connect()
+    original_exit_stack = server.exit_stack
+
+    await server.cleanup()
+    assert server.session is None
+    assert server.exit_stack is not original_exit_stack
+    assert server.server_initialize_result is None
+    assert counter == {"enter": 1, "exit": 1}
+
+    await server.connect()
+    await server.cleanup()
+    assert counter == {"enter": 2, "exit": 2}
+
+
+@pytest.mark.asyncio
+async def test_cleanup_cancellation_preserves_exit_stack_for_retry():
+    server = MCPServerStdio(
+        params={
+            "command": tee,
+        },
+        cache_tools_list=True,
+    )
+    cancelled_exit_stack = CancelThenCloseExitStack()
+
+    server.exit_stack = cancelled_exit_stack  # type: ignore[assignment]
+    server.session = object()  # type: ignore[assignment]
+    server.server_initialize_result = object()  # type: ignore[assignment]
+
+    with pytest.raises(asyncio.CancelledError):
+        await server.cleanup()
+
+    assert id(server.exit_stack) == id(cancelled_exit_stack)
+    assert server.session is not None
+    assert server.server_initialize_result is not None
+
+    await server.cleanup()
+
+    assert cancelled_exit_stack.close_calls == 2
+    assert id(server.exit_stack) != id(cancelled_exit_stack)
+    assert server.session is None
+    assert server.server_initialize_result is None
+
+
+@pytest.mark.asyncio
+async def test_cleanup_grouped_cancellation_preserves_exit_stack_for_retry():
+    server = MCPServerStdio(
+        params={
+            "command": tee,
+        },
+        cache_tools_list=True,
+    )
+    cancelled_exit_stack = CancelGroupThenCloseExitStack()
+
+    server.exit_stack = cancelled_exit_stack  # type: ignore[assignment]
+    server.session = object()  # type: ignore[assignment]
+    server.server_initialize_result = object()  # type: ignore[assignment]
+
+    with pytest.raises(BaseExceptionGroup):
+        await server.cleanup()
+
+    assert id(server.exit_stack) == id(cancelled_exit_stack)
+    assert server.session is not None
+    assert server.server_initialize_result is not None
+
+    await server.cleanup()
+
+    assert cancelled_exit_stack.close_calls == 2
+    assert id(server.exit_stack) != id(cancelled_exit_stack)
+    assert server.session is None
+    assert server.server_initialize_result is None
diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py
index 63e4d2e8f..8cbbe038b 100644
--- a/tests/test_result_cast.py
+++ b/tests/test_result_cast.py
@@ -7,7 +7,7 @@
 
 import pytest
 from openai.types.responses import ResponseOutputMessage, ResponseOutputText
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from agents import (
     Agent,
@@ -45,6 +45,16 @@ class Foo(BaseModel):
     bar: int
 
 
+def test_run_result_streaming_supports_pydantic_model_rebuild() -> None:
+    class StreamingRunContainer(BaseModel):
+        query_id: str
+        run_stream: RunResultStreaming | None
+
+        model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    StreamingRunContainer.model_rebuild()
+
+
 def _create_message(text: str) -> ResponseOutputMessage:
     return ResponseOutputMessage(
         id="msg",