diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index cd2dc7bfc7..a0c998757c 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -454,6 +454,7 @@ def as_tool( stream_callback: Callable[[AgentResponseUpdate], None] | Callable[[AgentResponseUpdate], Awaitable[None]] | None = None, + propagate_session: bool = False, ) -> FunctionTool: """Create a FunctionTool that wraps this agent. @@ -464,6 +465,12 @@ def as_tool( arg_description: The description for the function argument. If None, defaults to "Task for {tool_name}". stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). + propagate_session: If True, the parent agent's ``AgentSession`` is + forwarded to this sub-agent's ``run()`` call, so both agents + operate within the same logical session (sharing the same + ``session_id`` and provider-managed state, such as any stored + conversation history or metadata). Defaults to False, meaning + the sub-agent runs with a new, independent session. Returns: A FunctionTool that can be used as a tool by other agents. @@ -480,9 +487,12 @@ def as_tool( # Create an agent agent = Agent(client=client, name="research-agent", description="Performs research tasks") - # Convert the agent to a tool + # Convert the agent to a tool (independent session) research_tool = agent.as_tool() + # Convert the agent to a tool (shared session with parent) + research_tool = agent.as_tool(propagate_session=True) + # Use the tool with another agent coordinator = Agent(client=client, name="coordinator", tools=research_tool) """ @@ -509,16 +519,21 @@ async def agent_wrapper(**kwargs: Any) -> str: # Extract the input from kwargs using the specified arg_name input_text = kwargs.get(arg_name, "") - # Forward runtime context kwargs, excluding arg_name and conversation_id. - forwarded_kwargs = {k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options")} + # Extract parent session when propagate_session is enabled + parent_session = kwargs.get("session") if propagate_session else None + + # Forward runtime context kwargs, excluding framework-internal keys. + forwarded_kwargs = { + k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session") + } if stream_callback is None: # Use non-streaming mode - return (await self.run(input_text, stream=False, **forwarded_kwargs)).text + return (await self.run(input_text, stream=False, session=parent_session, **forwarded_kwargs)).text # Use streaming mode - accumulate updates and create final response response_updates: list[AgentResponseUpdate] = [] - async for update in self.run(input_text, stream=True, **forwarded_kwargs): + async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs): response_updates.append(update) if is_async_callback: await stream_callback(update) # type: ignore[misc] @@ -1061,6 +1076,9 @@ async def _prepare_run_context( # in function middleware context and tool invocation. existing_additional_args = opts.pop("additional_function_arguments", None) or {} additional_function_arguments = {**kwargs, **existing_additional_args} + # Include session so as_tool() wrappers with propagate_session=True can access it. + if active_session is not None: + additional_function_arguments["session"] = active_session # Build options dict from run() options merged with provided options run_opts: dict[str, Any] = { diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index c8d2d9bf8b..d41b87b707 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -707,6 +707,81 @@ async def test_chat_agent_as_tool_name_sanitization(client: SupportsChatGetRespo assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}" +async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None: + """Test that propagate_session=True forwards the parent's session to the sub-agent.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool(propagate_session=True) + + parent_session = AgentSession(session_id="parent-session-123") + parent_session.state["shared_key"] = "shared_value" + + # Spy on the agent's run method to capture the session argument + original_run = agent.run + captured_session = None + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_session + captured_session = kwargs.get("session") + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + + await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + + assert captured_session is parent_session + assert captured_session.session_id == "parent-session-123" + assert captured_session.state["shared_key"] == "shared_value" + + +async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None: + """Test that propagate_session defaults to False and does not forward the session.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool() # default: propagate_session=False + + parent_session = AgentSession(session_id="parent-session-456") + + original_run = agent.run + captured_session = None + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_session + captured_session = kwargs.get("session") + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + + await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + + assert captured_session is None + + +async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None: + """Test that shared session allows the sub-agent to read and write parent's state.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool(propagate_session=True) + + parent_session = AgentSession(session_id="shared-session") + parent_session.state["counter"] = 0 + + # The sub-agent receives the same session object, so mutations are shared + original_run = agent.run + captured_session = None + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_session + captured_session = kwargs.get("session") + if captured_session: + captured_session.state["counter"] += 1 + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + + await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + + # The parent's state should reflect the sub-agent's mutation + assert parent_session.state["counter"] == 1 + + async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None: """Test basic as_mcp_server functionality.""" agent = Agent(client=client, name="TestAgent", description="Test agent for MCP") diff --git a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py new file mode 100644 index 0000000000..33748437e0 --- /dev/null +++ b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable + +from agent_framework import AgentContext, AgentSession +from agent_framework.openai import OpenAIResponsesClient +from dotenv import load_dotenv + +load_dotenv() + +""" +Agent-as-Tool: Session Propagation Example + +Demonstrates how to share an AgentSession between a coordinator agent and a +sub-agent invoked as a tool using ``propagate_session=True``. + +When session propagation is enabled, both agents share the same session object, +including session_id and the mutable state dict. This allows correlated +conversation tracking and shared state across the agent hierarchy. + +The middleware functions below are purely for observability — they are NOT +required for session propagation to work. +""" + + +async def log_session( + context: AgentContext, + call_next: Callable[[], Awaitable[None]], +) -> None: + """Agent middleware that logs the session received by each agent. + + NOT required for session propagation — only used to observe the flow. + If propagation is working, both agents will show the same session_id. + """ + session: AgentSession | None = context.session + agent_name = context.agent.name or "unknown" + session_id = session.session_id if session else None + state = dict(session.state) if session else {} + print(f" [{agent_name}] session_id={session_id}, state={state}") + await call_next() + + +async def main() -> None: + print("=== Agent-as-Tool: Session Propagation ===\n") + + client = OpenAIResponsesClient() + + # --- Sub-agent: a research specialist --- + # The sub-agent has the same log_session middleware to prove it receives the session. + research_agent = client.as_agent( + name="ResearchAgent", + instructions="You are a research assistant. Provide concise answers.", + middleware=[log_session], + ) + + # propagate_session=True: the coordinator's session will be forwarded + research_tool = research_agent.as_tool( + name="research", + description="Research a topic and return findings", + arg_name="query", + arg_description="The research query", + propagate_session=True, + ) + + # --- Coordinator agent --- + coordinator = client.as_agent( + name="CoordinatorAgent", + instructions="You coordinate research. Use the 'research' tool to look up information.", + tools=[research_tool], + middleware=[log_session], + ) + + # Create a shared session and put some state in it + session = coordinator.create_session() + session.state["request_source"] = "demo" + print(f"Session ID: {session.session_id}") + print(f"Session state before run: {session.state}\n") + + query = "What are the latest developments in quantum computing?" + print(f"User: {query}\n") + + result = await coordinator.run(query, session=session) + + print(f"\nCoordinator: {result}\n") + print(f"Session state after run: {session.state}") + print( + "\nIf both agents show the same session_id above, session propagation is working." + ) + + +if __name__ == "__main__": + asyncio.run(main())