Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class BaseSession(
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_response_routers: list["ResponseRouter"]
_closing: bool = False

def __init__(
self,
Expand Down Expand Up @@ -252,6 +253,9 @@ async def send_request(
Do not use this method to emit notifications! Use send_notification()
instead.
"""
if self._closing:
raise McpError(ErrorData(code=CONNECTION_CLOSED, message="Connection closed"))

request_id = self._request_id
self._request_id = request_id + 1

Expand Down Expand Up @@ -307,7 +311,8 @@ async def send_request(
return result_type.model_validate(response_or_error.result)

finally:
self._response_streams.pop(request_id, None)
self._response_streams.pop(request_id, None) if not self._closing else None

self._progress_callbacks.pop(request_id, None)
await response_stream.aclose()
await response_stream_reader.aclose()
Expand Down Expand Up @@ -444,15 +449,17 @@ async def _receive_loop(self) -> None:
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception: # pragma: no cover
# Stream might already be closed
pass
self._response_streams.clear()
self._closing = True
with anyio.CancelScope(shield=True):
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception: # pragma: no cover
# Stream might already be closed
pass
self._response_streams.clear()

def _normalize_request_id(self, response_id: RequestId) -> RequestId:
"""
Expand Down Expand Up @@ -508,7 +515,7 @@ async def _handle_response(self, message: SessionMessage) -> None:
return # Handled

# Fall back to normal response streams
stream = self._response_streams.pop(response_id, None)
stream = self._response_streams.pop(response_id, None) if not self._closing else None
if stream: # pragma: no cover
await stream.send(root)
else: # pragma: no cover
Expand Down
71 changes: 71 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,74 @@ async def mock_server():
await ev_closed.wait()
with anyio.fail_after(1):
await ev_response.wait()


@pytest.mark.anyio
async def test_session_aexit_cleanup():
"""Test that the session is closing properly, cleaning up all resources."""
pending_request_ids: list[int | str] = []
requests_received = anyio.Event()
client_session_closed = anyio.Event()

async with (
anyio.create_task_group() as tg,
create_client_server_memory_streams() as (client_streams, server_streams),
):
client_read, client_write = client_streams
server_read, _ = server_streams

async def mock_server():
"""Block responses to simulate a server that does not respond."""
# Wait for two ping requests
for _ in range(2):
message = await server_read.receive()
assert isinstance(message, SessionMessage)
root = message.message.root
assert isinstance(root, JSONRPCRequest)
assert root.method == "ping"
pending_request_ids.append(root.id)

# Signal that both requests have been received
requests_received.set()

# Wait for the client session to be closed
# This ensures the cleanup logic in finally block has time to run
await client_session_closed.wait()

async def send_ping(session: ClientSession):
# Since we are closing the session, "Connection closed" McpError is expected
with pytest.raises(McpError) as e:
await session.send_ping()
assert "Connection closed" in str(e.value)

# Start the mock server in the background
tg.start_soon(mock_server)

# Create a session and send multiple ping requests in background
async with ClientSession(read_stream=client_read, write_stream=client_write) as session:
# Verify initial state
assert len(session._response_streams) == 0

# Start two ping requests in background
tg.start_soon(send_ping, session)
tg.start_soon(send_ping, session)

# Wait for both requests to be sent and received by server
await requests_received.wait()
await anyio.sleep(0.1) # Give time for streams to be created

# Verify we have 2 response streams
assert len(session._response_streams) == 2

# We close the session by escaping the async with block
client_session_closed.set()

# Since the sesssion has been closed, "Connection closed" McpError is expected
with pytest.raises(McpError) as e:
await session.send_ping()
assert "Connection closed" in str(e.value)

# Verify all response streams have been cleaned up
# (This happens when the async with block exits and __aexit__ is called)
assert session is not None
assert len(session._response_streams) == 0