Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ class MCPConfig(TypedDict):
default_tool_error_function.
"""

include_server_in_tool_names: NotRequired[bool]
"""If True, MCP tools are exposed as `<server_name>_<tool_name>` to avoid collisions across
servers that publish the same tool names. Defaults to False.
"""


@dataclass
class AgentBase(Generic[TContext]):
Expand Down Expand Up @@ -182,12 +187,14 @@ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[
failure_error_function = self.mcp_config.get(
"failure_error_function", default_tool_error_function
)
include_server_in_tool_names = self.mcp_config.get("include_server_in_tool_names", False)
return await MCPUtil.get_all_function_tools(
self.mcp_servers,
convert_schemas_to_strict,
run_context,
self,
failure_error_function=failure_error_function,
include_server_in_tool_names=include_server_in_tool_names,
)

async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
Expand Down
54 changes: 40 additions & 14 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ async def get_all_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
include_server_in_tool_names: bool = False,
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
Expand All @@ -189,6 +190,7 @@ async def get_all_function_tools(
run_context,
agent,
failure_error_function=failure_error_function,
include_server_in_tool_names=include_server_in_tool_names,
)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
Expand All @@ -209,24 +211,39 @@ async def get_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
include_server_in_tool_names: bool = False,
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

with mcp_tools_span(server=server.name) as span:
tools = await server.list_tools(run_context, agent)
span.span_data.result = [tool.name for tool in tools]

tool_name_prefix = (
cls._server_tool_name_prefix(server.name) if include_server_in_tool_names else ""
)
return [
cls.to_function_tool(
tool,
server,
convert_schemas_to_strict,
agent,
failure_error_function=failure_error_function,
tool_name_override=f"{tool_name_prefix}{tool.name}" if tool_name_prefix else None,
)
for tool in tools
]

@staticmethod
def _server_tool_name_prefix(server_name: str) -> str:
normalized = "".join(
char if char.isalnum() or char in ("_", "-") else "_" for char in server_name
)
normalized = normalized.strip("_-")
if not normalized:
normalized = "server"
return f"{normalized}_"

@classmethod
def to_function_tool(
cls,
Expand All @@ -235,6 +252,7 @@ def to_function_tool(
convert_schemas_to_strict: bool,
agent: AgentBase | None = None,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
tool_name_override: str | None = None,
) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool.

Expand All @@ -243,7 +261,10 @@ def to_function_tool(
When omitted, this helper preserves the historical behavior and leaves
``needs_approval`` disabled.
"""
invoke_func_impl = functools.partial(cls.invoke_mcp_tool, server, tool)
tool_name = tool_name_override or tool.name
invoke_func_impl = functools.partial(
cls.invoke_mcp_tool, server, tool, tool_display_name=tool_name
)
effective_failure_error_function = server._get_failure_error_function(
failure_error_function
)
Expand Down Expand Up @@ -280,18 +301,18 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
SpanError(
message="Error running tool (non-fatal)",
data={
"tool_name": tool.name,
"tool_name": tool_name,
"error": str(e),
},
)
)

# Log the error.
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} failed")
logger.debug(f"MCP tool {tool_name} failed")
else:
logger.error(
f"MCP tool {tool.name} failed: {input_json} {e}",
f"MCP tool {tool_name} failed: {input_json} {e}",
exc_info=e,
)

Expand All @@ -302,7 +323,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput:
) = server._get_needs_approval_for_tool(tool, agent)

return FunctionTool(
name=tool.name,
name=tool_name,
description=tool.description or "",
params_json_schema=schema,
on_invoke_tool=invoke_func,
Expand Down Expand Up @@ -361,25 +382,30 @@ async def invoke_mcp_tool(
input_json: str,
*,
meta: dict[str, Any] | None = None,
tool_display_name: str | None = None,
) -> ToolOutput:
"""Invoke an MCP tool and return the result as ToolOutput."""
tool_name = tool_display_name or tool.name
try:
json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
except Exception as e:
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invalid JSON input for tool {tool.name}")
logger.debug(f"Invalid JSON input for tool {tool_name}")
else:
logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
logger.debug(f"Invalid JSON input for tool {tool_name}: {input_json}")
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool.name}: {input_json}"
f"Invalid JSON input for tool {tool_name}: {input_json}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invoking MCP tool {tool.name}")
logger.debug(f"Invoking MCP tool {tool_name}")
else:
logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")
logger.debug(f"Invoking MCP tool {tool_name} with input {input_json}")

try:
# Meta resolvers should receive the canonical MCP tool identifier.
# The display name may be prefixed for collision avoidance, but call_tool
# and downstream resolver routing still key off the original tool name.
resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data)
merged_meta = cls._merge_mcp_meta(resolved_meta, meta)
if merged_meta is None:
Expand All @@ -390,15 +416,15 @@ async def invoke_mcp_tool(
# Re-raise UserError as-is (it already has a good message)
raise
except Exception as e:
logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}")
logger.error(f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}")
raise AgentsException(
f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}"
f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} completed.")
logger.debug(f"MCP tool {tool_name} completed.")
else:
logger.debug(f"MCP tool {tool.name} returned {result}")
logger.debug(f"MCP tool {tool_name} returned {result}")

# If structured content is requested and available, use it exclusively
tool_output: ToolOutput
Expand Down
169 changes: 168 additions & 1 deletion tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel, TypeAdapter

from agents import Agent, FunctionTool, RunContextWrapper, default_tool_error_function
from agents.exceptions import AgentsException, ModelBehaviorError
from agents.exceptions import AgentsException, ModelBehaviorError, UserError
from agents.mcp import MCPServer, MCPUtil
from agents.tool_context import ToolContext

Expand Down Expand Up @@ -80,6 +80,97 @@ async def test_get_all_function_tools():
assert all(tool.name in names for tool in tools)


@pytest.mark.asyncio
async def test_get_all_function_tools_duplicate_names_raise_by_default():
server1 = FakeMCPServer(server_name="github")
server1.add_tool("create_issue", {})

server2 = FakeMCPServer(server_name="linear")
server2.add_tool("create_issue", {})

run_context = RunContextWrapper(context=None)
agent = Agent(name="test_agent", instructions="Test agent")

with pytest.raises(UserError, match="Duplicate tool names found across MCP servers"):
await MCPUtil.get_all_function_tools([server1, server2], False, run_context, agent)


@pytest.mark.asyncio
async def test_get_all_function_tools_can_prefix_with_server_name():
server1 = FakeMCPServer(server_name="GitHub MCP Server")
server1.add_tool("create_issue", {})

server2 = FakeMCPServer(server_name="linear")
server2.add_tool("create_issue", {})

run_context = RunContextWrapper(context=None)
agent = Agent(
name="test_agent",
instructions="Test agent",
mcp_servers=[server1, server2],
mcp_config={"include_server_in_tool_names": True},
)

tools = await agent.get_mcp_tools(run_context)
tool_names = {tool.name for tool in tools}
assert tool_names == {"GitHub_MCP_Server_create_issue", "linear_create_issue"}

github_tool = next(tool for tool in tools if tool.name == "GitHub_MCP_Server_create_issue")
linear_tool = next(tool for tool in tools if tool.name == "linear_create_issue")
assert isinstance(github_tool, FunctionTool)
assert isinstance(linear_tool, FunctionTool)

github_ctx = ToolContext(
context=None,
tool_name=github_tool.name,
tool_call_id="prefixed_call_1",
tool_arguments='{"title":"a"}',
)
linear_ctx = ToolContext(
context=None,
tool_name=linear_tool.name,
tool_call_id="prefixed_call_2",
tool_arguments='{"title":"b"}',
)

github_result = await github_tool.on_invoke_tool(github_ctx, '{"title":"a"}')
linear_result = await linear_tool.on_invoke_tool(linear_ctx, '{"title":"b"}')
assert isinstance(github_result, dict)
assert isinstance(linear_result, dict)
assert server1.tool_calls == ["create_issue"]
assert server2.tool_calls == ["create_issue"]


@pytest.mark.asyncio
async def test_get_all_function_tools_prefix_falls_back_for_empty_server_name_slug():
server = FakeMCPServer(server_name="!!!")
server.add_tool("search", {})

run_context = RunContextWrapper(context=None)
agent = Agent(
name="test_agent",
instructions="Test agent",
mcp_servers=[server],
mcp_config={"include_server_in_tool_names": True},
)

tools = await agent.get_mcp_tools(run_context)
assert len(tools) == 1
prefixed_tool = tools[0]
assert isinstance(prefixed_tool, FunctionTool)
assert prefixed_tool.name == "server_search"

tool_context = ToolContext(
context=None,
tool_name=prefixed_tool.name,
tool_call_id="prefixed_call_3",
tool_arguments='{"query":"docs"}',
)
result = await prefixed_tool.on_invoke_tool(tool_context, '{"query":"docs"}')
assert isinstance(result, dict)
assert server.tool_calls == ["search"]


@pytest.mark.asyncio
async def test_invoke_mcp_tool():
"""Test that the invoke_mcp_tool function invokes an MCP tool and returns the result."""
Expand Down Expand Up @@ -125,6 +216,48 @@ def resolve_meta(context):
assert captured["arguments"] == {}


@pytest.mark.asyncio
async def test_mcp_meta_resolver_uses_original_tool_name_with_prefixed_display_name():
captured: dict[str, Any] = {}

def resolve_meta(context):
captured["tool_name"] = context.tool_name
return {"scope": "meta"}

server = FakeMCPServer(
server_name="GitHub MCP Server",
tool_meta_resolver=resolve_meta,
)
server.add_tool("create_issue", {})

run_context = RunContextWrapper(context=None)
agent = Agent(
name="test_agent",
instructions="Test agent",
mcp_servers=[server],
mcp_config={"include_server_in_tool_names": True},
)

tools = await agent.get_mcp_tools(run_context)
assert len(tools) == 1

prefixed_tool = tools[0]
assert isinstance(prefixed_tool, FunctionTool)
assert prefixed_tool.name == "GitHub_MCP_Server_create_issue"

tool_context = ToolContext(
context=None,
tool_name=prefixed_tool.name,
tool_call_id="prefixed_call_meta_1",
tool_arguments='{"title":"a"}',
)
await prefixed_tool.on_invoke_tool(tool_context, '{"title":"a"}')

assert captured["tool_name"] == "create_issue"
assert server.tool_calls == ["create_issue"]
assert server.tool_metas[-1] == {"scope": "meta"}


@pytest.mark.asyncio
async def test_mcp_meta_resolver_does_not_mutate_arguments():
def resolve_meta(context):
Expand Down Expand Up @@ -290,6 +423,40 @@ async def call_tool(
assert "Timed out" in result


@pytest.mark.asyncio
async def test_mcp_tool_failure_logs_prefixed_name_when_tool_data_logging_enabled(
caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch
):
import agents._debug as debug_settings

caplog.set_level(logging.ERROR)
monkeypatch.setattr(debug_settings, "DONT_LOG_TOOL_DATA", False)

server = CrashingFakeMCPServer()
server.add_tool("crashing_tool", {})

mcp_tool = MCPTool(name="crashing_tool", inputSchema={})
agent = Agent(name="test-agent")
function_tool = MCPUtil.to_function_tool(
mcp_tool,
server,
convert_schemas_to_strict=False,
agent=agent,
tool_name_override="prefixed_crashing_tool",
)

tool_context = ToolContext(
context=None,
tool_name="prefixed_crashing_tool",
tool_call_id="test_call_prefixed_log",
tool_arguments="{}",
)
result = await function_tool.on_invoke_tool(tool_context, "{}")

assert isinstance(result, str)
assert "MCP tool prefixed_crashing_tool failed" in caplog.text


@pytest.mark.asyncio
async def test_to_function_tool_legacy_call_without_agent_uses_server_policy():
"""Legacy three-argument to_function_tool calls should honor server policy."""
Expand Down