Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from agent_framework import Agent, SupportsAgentRun
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
from agent_framework._sessions import AgentSession
from agent_framework._sessions import AgentSession, BaseContextProvider, BaseHistoryProvider, InMemoryHistoryProvider
from agent_framework._tools import FunctionTool, tool
from agent_framework._types import AgentResponse, AgentResponseUpdate, Content, Message
from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
Expand Down Expand Up @@ -390,12 +390,19 @@ def _clone_chat_agent(self, agent: Agent) -> Agent:
"user": options.get("user"),
}

# Handoff workflows manage full conversation state via _full_conversation.
# Suppress history providers to prevent duplicate messages on approval resume.
context_providers: list[BaseContextProvider] = [
p for p in agent.context_providers if not isinstance(p, BaseHistoryProvider)
]
context_providers.append(InMemoryHistoryProvider(load_messages=False, store_inputs=False, store_outputs=False))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I like this, because it means you cannot have an Agent that stores the conversations passed to it to disk. Could you instead filter for any history providers that have load_messages=True and raise if one is found?


return Agent(
client=agent.client,
id=agent.id,
name=agent.name,
description=agent.description,
context_providers=agent.context_providers,
context_providers=context_providers,
middleware=middleware,
default_options=cloned_options, # type: ignore[arg-type]
)
Expand Down
94 changes: 93 additions & 1 deletion python/packages/orchestrations/tests/test_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,99 @@ async def _get() -> ChatResponse:
assert client.resume_validated is True


async def test_handoff_replay_serializes_handoff_function_results() -> None:
async def test_handoff_tool_approval_does_not_duplicate_tool_calls_messages() -> None:
    """Approval resume must not replay assistant tool_calls a second time (#4411).

    Uses a fake chat client that, on the resumed turn, inspects the full message
    history and fails if any assistant `function_call` lacks a matching
    `function_result` — the condition that triggers an OpenAI 400 error.
    """

    @tool(name="submit_refund", approval_mode="always_require")
    def submit_refund() -> str:
        return "ok"

    class DuplicateDetectingClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
        def __init__(self) -> None:
            ChatMiddlewareLayer.__init__(self)
            FunctionInvocationLayer.__init__(self)
            BaseChatClient.__init__(self)
            self._turn = 0
            self.resume_validated = False

        def _inner_get_response(
            self,
            *,
            messages: Sequence[Message],
            stream: bool,
            options: Mapping[str, Any],
            **kwargs: Any,
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            del options, kwargs

            # Capture which turn this is before bumping the counter.
            first_turn = self._turn == 0
            self._turn += 1

            if first_turn:
                contents = [
                    Content.from_function_call(
                        call_id="refund-call-1",
                        name="submit_refund",
                        arguments={},
                    )
                ]
            else:
                # Each assistant message with tool_calls must have a matching tool response.
                # Duplicate tool_calls without responses trigger an OpenAI 400 error.
                call_ids: list[str] = []
                result_ids: set[str] = set()
                for message in messages:
                    for part in message.contents:
                        if part.type == "function_call" and part.call_id:
                            call_ids.append(part.call_id)
                        elif part.type == "function_result" and part.call_id:
                            result_ids.add(part.call_id)
                unmatched = [call_id for call_id in call_ids if call_id not in result_ids]
                if unmatched:
                    raise AssertionError(
                        f"Assistant tool_calls without matching tool response: {unmatched}. "
                        "This would cause a 400 error from the OpenAI Chat Completions API."
                    )
                self.resume_validated = True
                contents = [Content.from_text(text="Refund submitted.")]

            if not stream:

                async def _respond() -> ChatResponse:
                    return ChatResponse(
                        messages=[Message(role="assistant", contents=contents)],
                        response_id="dup-detect",
                    )

                return _respond()

            async def _updates() -> AsyncIterable[ChatResponseUpdate]:
                yield ChatResponseUpdate(contents=contents, role="assistant", finish_reason="stop")

            return ResponseStream(_updates(), finalizer=lambda ups: ChatResponse.from_updates(ups))

    client = DuplicateDetectingClient()
    agent = Agent(
        id="refund_agent",
        name="refund_agent",
        client=client,
        tools=[submit_refund],
    )
    workflow = (
        HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build()
    )

    # First run: the approval-gated tool should surface a request_info event.
    initial_events = await _drain(workflow.run("Refund order 123", stream=True))
    pending_approvals = [
        ev for ev in initial_events if ev.type == "request_info" and isinstance(ev.data, Content)
    ]
    assert pending_approvals
    request = pending_approvals[0]

    # Approve and resume; the fake client validates history on the second turn.
    approval = request.data.to_function_approval_response(True)
    await _drain(workflow.run(stream=True, responses={request.request_id: approval}))

    assert client.resume_validated is True
"""Returning to the same agent must not replay dict tool outputs."""

class ReplaySafeHandoffClient(ChatMiddlewareLayer[Any], FunctionInvocationLayer[Any], BaseChatClient[Any]):
Expand Down