diff --git a/src/agents/result.py b/src/agents/result.py index 5e27634f7..5ee072eac 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -480,7 +480,10 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: try: while True: self._check_errors() - if self._stored_exception: + should_drain_queued_events = isinstance(self._stored_exception, MaxTurnsExceeded) + if self._stored_exception and ( + not should_drain_queued_events or self._event_queue.empty() + ): logger.debug("Breaking due to stored exception") self.is_complete = True break diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 10004e88f..ae234a706 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1325,6 +1325,56 @@ async def test_tool() -> str: pass +@pytest.mark.asyncio +async def test_streaming_max_turns_emits_pending_tool_output_events() -> None: + async def test_tool() -> str: + return "tool_result" + + tool = function_tool(test_tool, name_override="test_tool") + model, agent = make_model_and_agent(name="test", tools=[tool]) + + queue_function_call_and_text( + model, + get_function_tool_call("test_tool", json.dumps({})), + followup=[get_text_message("done")], + ) + + result = Runner.run_streamed(agent, input="Use test_tool", max_turns=1) + streamed_item_types: list[str] = [] + + with pytest.raises(MaxTurnsExceeded): + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + streamed_item_types.append(event.item.type) + + assert "tool_call_item" in streamed_item_types + assert "tool_call_output_item" in streamed_item_types + + +@pytest.mark.asyncio +async def test_streaming_non_max_turns_exception_does_not_emit_queued_events() -> None: + model, agent = make_model_and_agent(name="test") + model.set_next_output([get_text_message("done")]) + + result = Runner.run_streamed(agent, input="hello") + result.cancel() + await asyncio.sleep(0) + + while not result._event_queue.empty(): + result._event_queue.get_nowait() + result._event_queue.task_done() + + result._stored_exception = RuntimeError("guardrail-triggered") + result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=agent)) + + streamed_events: list[StreamEvent] = [] + with pytest.raises(RuntimeError, match="guardrail-triggered"): + async for event in result.stream_events(): + streamed_events.append(event) + + assert streamed_events == [] + + @pytest.mark.asyncio async def test_streaming_hitl_server_conversation_tracker_priming(): """Test that resuming streaming run from RunState primes server conversation tracker."""