diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py
index d2ff5af959..bc2ccc47b7 100644
--- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py
+++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py
@@ -39,7 +39,7 @@
 
 from agent_framework import Agent, SupportsAgentRun
 from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
-from agent_framework._sessions import AgentSession
+from agent_framework._sessions import AgentSession, BaseContextProvider, BaseHistoryProvider, InMemoryHistoryProvider
 from agent_framework._tools import FunctionTool, tool
 from agent_framework._types import AgentResponse, AgentResponseUpdate, Content, Message
 from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
@@ -390,12 +390,20 @@ def _clone_chat_agent(self, agent: Agent) -> Agent:
             "user": options.get("user"),
         }
 
+        # Handoff workflows manage full conversation state via _full_conversation.
+        # Replace history providers with a disabled InMemoryHistoryProvider (presumably so no default
+        # one is added; confirm) to prevent duplicate messages on approval resume.
+        context_providers: list[BaseContextProvider] = [
+            p for p in agent.context_providers if not isinstance(p, BaseHistoryProvider)
+        ]
+        context_providers.append(InMemoryHistoryProvider(load_messages=False, store_inputs=False, store_outputs=False))
+
         return Agent(
             client=agent.client,
             id=agent.id,
             name=agent.name,
             description=agent.description,
-            context_providers=agent.context_providers,
+            context_providers=context_providers,
             middleware=middleware,
             default_options=cloned_options,  # type: ignore[arg-type]
         )
diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py
index e0d94355b6..a064f5c935 100644
--- a/python/packages/orchestrations/tests/test_handoff.py
+++ b/python/packages/orchestrations/tests/test_handoff.py
@@ -472,7 +472,102 @@ async def _get() -> ChatResponse:
     assert client.resume_validated is True
 
 
+async def test_handoff_tool_approval_does_not_duplicate_tool_calls_messages() -> None:
+    """InMemoryHistoryProvider must not cause duplicate tool_calls on approval resume (#4411)."""
+
+    @tool(name="submit_refund", approval_mode="always_require")
+    def submit_refund() -> str:
+        return "ok"
+
+    class DuplicateDetectingClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
+        def __init__(self) -> None:
+            ChatMiddlewareLayer.__init__(self)
+            FunctionInvocationLayer.__init__(self)
+            BaseChatClient.__init__(self)
+            self._call_index = 0
+            self.resume_validated = False
+
+        def _inner_get_response(
+            self,
+            *,
+            messages: Sequence[Message],
+            stream: bool,
+            options: Mapping[str, Any],
+            **kwargs: Any,
+        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
+            del options
+            del kwargs
+
+            if self._call_index == 0:
+                contents = [
+                    Content.from_function_call(
+                        call_id="refund-call-1",
+                        name="submit_refund",
+                        arguments={},
+                    )
+                ]
+            else:
+                # Each assistant message with tool_calls must have a matching tool response.
+                # Duplicate tool_calls without responses trigger an OpenAI 400 error.
+                tool_call_ids: list[str] = []
+                tool_result_ids: set[str] = set()
+                for msg in messages:
+                    for content in msg.contents:
+                        if content.type == "function_call" and content.call_id:
+                            tool_call_ids.append(content.call_id)
+                        elif content.type == "function_result" and content.call_id:
+                            tool_result_ids.add(content.call_id)
+                unmatched = [cid for cid in tool_call_ids if cid not in tool_result_ids]
+                if unmatched:
+                    raise AssertionError(
+                        f"Assistant tool_calls without matching tool response: {unmatched}. "
+                        "This would cause a 400 error from the OpenAI Chat Completions API."
+                    )
+                self.resume_validated = True
+                contents = [Content.from_text(text="Refund submitted.")]
+
+            self._call_index += 1
+
+            if stream:
+
+                async def _stream() -> AsyncIterable[ChatResponseUpdate]:
+                    yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")
+
+                return ResponseStream(_stream(), finalizer=lambda updates: ChatResponse.from_updates(updates))
+
+            async def _get() -> ChatResponse:
+                return ChatResponse(
+                    messages=[Message(role="assistant", contents=contents)],
+                    response_id="dup-detect",
+                )
+
+            return _get()
+
+    client = DuplicateDetectingClient()
+    agent = Agent(
+        id="refund_agent",
+        name="refund_agent",
+        client=client,
+        tools=[submit_refund],
+    )
+    workflow = (
+        HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
+    )
+
+    first_events = await _drain(workflow.run("Refund order 123", stream=True))
+    approval_requests = [
+        event for event in first_events if event.type == "request_info" and isinstance(event.data, Content)
+    ]
+    assert approval_requests
+    first_request = approval_requests[0]
+
+    approval_response = first_request.data.to_function_approval_response(True)
+    await _drain(workflow.run(stream=True, responses={first_request.request_id: approval_response}))
+
+    assert client.resume_validated is True
+
+
 async def test_handoff_replay_serializes_handoff_function_results() -> None:
     """Returning to the same agent must not replay dict tool outputs."""
 
     class ReplaySafeHandoffClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):