Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2844,6 +2844,7 @@ async def __anext__(self) -> UpdateT:
except StopAsyncIteration:
self._consumed = True
await self._run_cleanup_hooks()
await self.get_final_response()
raise
except Exception:
await self._run_cleanup_hooks()
Expand Down Expand Up @@ -2895,29 +2896,33 @@ async def get_final_response(self) -> FinalT:
await self._get_stream()
if self._inner_stream is None:
raise RuntimeError("Inner stream not available")
if not self._finalized:
if not self._finalized and not self._consumed:
# Consume outer stream (which delegates to inner) if not already consumed
if not self._consumed:
async for _ in self:
pass
async for _ in self:
pass

# First, finalize the inner stream and run its result hooks
# Re-check: __anext__ auto-finalization may have already finalized this stream
if not self._finalized:
# This ensures inner post-processing (e.g., context provider notifications) runs
if self._inner_stream._finalizer is not None:
inner_result: Any = self._inner_stream._finalizer(self._inner_stream._updates)
if isinstance(inner_result, Awaitable):
inner_result = await inner_result
# Skip if inner stream was already finalized (e.g., via auto-finalization on iteration)
if not self._inner_stream._finalized:
if self._inner_stream._finalizer is not None:
inner_result: Any = self._inner_stream._finalizer(self._inner_stream._updates)
if isinstance(inner_result, Awaitable):
inner_result = await inner_result
else:
inner_result = self._inner_stream._updates
# Run inner stream's result hooks
for hook in self._inner_stream._result_hooks:
hooked = hook(inner_result)
if isinstance(hooked, Awaitable):
hooked = await hooked
if hooked is not None:
inner_result = hooked
self._inner_stream._final_result = inner_result
self._inner_stream._finalized = True
else:
inner_result = self._inner_stream._updates
# Run inner stream's result hooks
for hook in self._inner_stream._result_hooks:
hooked = hook(inner_result)
if isinstance(hooked, Awaitable):
hooked = await hooked
if hooked is not None:
inner_result = hooked
self._inner_stream._final_result = inner_result
self._inner_stream._finalized = True
inner_result = self._inner_stream._final_result

# Now finalize the outer stream with its own finalizer
# If outer has no finalizer, use inner's result (preserves from_awaitable behavior)
Expand All @@ -2938,11 +2943,11 @@ async def get_final_response(self) -> FinalT:
self._final_result = result
self._finalized = True
return self._final_result # type: ignore[return-value]
if not self._finalized and not self._consumed:
async for _ in self:
pass
# Re-check: __anext__ auto-finalization may have already finalized this stream
if not self._finalized:
if not self._consumed:
async for _ in self:
pass
# Use finalizer if configured, otherwise return collected updates
if self._finalizer is not None:
result = self._finalizer(self._updates)
if isinstance(result, Awaitable):
Expand Down
34 changes: 34 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,40 @@ async def test_chat_client_agent_streaming_session_id_set_without_get_final_resp
assert session.service_session_id == "resp_123"


async def test_chat_client_agent_streaming_session_history_saved_without_get_final_response(
    chat_client_base: SupportsChatGetResponse,
) -> None:
    """Check that draining a streaming run by iteration alone persists session history.

    The stream is consumed with ``async for`` only; ``get_final_response()`` is never
    called. Completing the iteration is expected to auto-finalize the stream, which in
    turn should let the after_run providers store the conversation in the session.
    """
    from agent_framework._sessions import InMemoryHistoryProvider

    # A single assistant update is all the fake client will stream back.
    assistant_update = ChatResponseUpdate(
        contents=[Content.from_text("Hello Alice!")],
        role="assistant",
        response_id="resp_1",
        finish_reason="stop",
    )
    chat_client_base.streaming_responses = [[assistant_update]]

    agent = Agent(client=chat_client_base)
    session = agent.create_session()

    # Drain the stream by iteration only -- deliberately no get_final_response() call.
    async for _ in agent.run("My name is Alice", session=session, stream=True):
        pass

    # History is looked up under the in-memory provider's default source id.
    provider_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
    chat_messages: list[Message] = provider_state.get("messages", [])
    assert len(chat_messages) == 2
    user_message, assistant_message = chat_messages
    assert user_message.text == "My name is Alice"
    assert assistant_message.text == "Hello Alice!"


async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
from agent_framework._sessions import InMemoryHistoryProvider

Expand Down
52 changes: 52 additions & 0 deletions python/packages/core/tests/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2658,6 +2658,58 @@ async def test_updates_property_returns_collected(self) -> None:
assert stream.updates[0].text == "update_0"
assert stream.updates[1].text == "update_1"

async def test_auto_finalize_on_iteration_completion(self) -> None:
    """Exhausting the stream with ``async for`` should leave it finalized."""
    stream = ResponseStream(_generate_updates(2), finalizer=_combine_updates)

    # Consume every update; get_final_response() is intentionally not called.
    async for _ in stream:
        pass

    assert stream._finalized is True
    combined = stream._final_result
    assert combined is not None
    # The finalizer output is expected to be the concatenation of both updates.
    assert combined.text == "update_0update_1"

async def test_auto_finalize_runs_result_hooks(self) -> None:
    """Completing iteration should invoke the registered result hooks automatically."""
    hook_ran = False

    def record_and_tag(response: ChatResponse) -> ChatResponse:
        # Record the invocation and tag the response so the effect is visible later.
        nonlocal hook_ran
        hook_ran = True
        response.additional_properties["auto_finalized"] = True
        return response

    stream = ResponseStream(_generate_updates(2), finalizer=_combine_updates, result_hooks=[record_and_tag])

    async for _ in stream:
        pass

    # The hook must already have fired before get_final_response() is requested.
    assert hook_ran is True
    final = await stream.get_final_response()
    assert final.additional_properties["auto_finalized"] is True

async def test_get_final_response_idempotent_after_auto_finalize(self) -> None:
    """Repeated get_final_response() calls after iteration must not re-run the finalizer."""
    finalizer_calls = 0

    def counting_finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse:
        # Count invocations, then delegate to the regular combining finalizer.
        nonlocal finalizer_calls
        finalizer_calls += 1
        return _combine_updates(updates)

    stream = ResponseStream(_generate_updates(2), finalizer=counting_finalizer)

    async for _ in stream:
        pass

    first = await stream.get_final_response()
    second = await stream.get_final_response()

    # One finalizer run in total: iteration plus two explicit calls share one result.
    assert finalizer_calls == 1
    assert first.text == second.text


class TestResponseStreamTransformHooks:
"""Tests for transform hooks (per-update processing)."""
Expand Down
51 changes: 13 additions & 38 deletions python/packages/core/tests/workflow/test_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Literal, overload

import pytest

from agent_framework import (
AgentExecutor,
AgentResponse,
Expand Down Expand Up @@ -59,30 +60,19 @@ def run(
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.call_count += 1
if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(
contents=[
Content.from_text(
text=f"Response #{self.call_count}: {self.name}"
)
]
contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")]
)

return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)

async def _run() -> AgentResponse:
return AgentResponse(
messages=[
Message("assistant", [f"Response #{self.call_count}: {self.name}"])
]
)
return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])])

return _run()

Expand Down Expand Up @@ -120,10 +110,7 @@ def run(
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
Expand All @@ -138,19 +125,17 @@ async def _mark_result_hook_called(
self.result_hook_called = True
return response

return ResponseStream(
_stream(), finalizer=AgentResponse.from_updates
).with_result_hook(_mark_result_hook_called)
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
_mark_result_hook_called
)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["hook test"])])

return _run()


async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> (
None
):
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
executor = AgentExecutor(agent, id="hook_exec")
Expand Down Expand Up @@ -217,9 +202,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:

executor_state = executor_states[executor.id] # type: ignore[index]
assert "cache" in executor_state, "Checkpoint should store executor cache state"
assert "agent_session" in executor_state, (
"Checkpoint should store executor session state"
)
assert "agent_session" in executor_state, "Checkpoint should store executor session state"

# Verify session state structure
session_state = executor_state["agent_session"] # type: ignore[index]
Expand All @@ -240,15 +223,11 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert restored_agent.call_count == 0

# Build new workflow with the restored executor
wf_resume = SequentialBuilder(
participants=[restored_executor], checkpoint_storage=storage
).build()
wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build()

# Resume from checkpoint
resumed_output: AgentExecutorResponse | None = None
async for ev in wf_resume.run(
checkpoint_id=restore_checkpoint.checkpoint_id, stream=True
):
async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True):
if ev.type == "output":
resumed_output = ev.data # type: ignore[assignment]
if ev.type == "status" and ev.state in (
Expand Down Expand Up @@ -391,11 +370,7 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
assert options is not None
assert options["additional_function_arguments"]["custom"] == 1

warned_keys = {
r.message.split("'")[1]
for r in caplog.records
if "reserved" in r.message.lower()
}
warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
assert warned_keys == {"session", "stream", "messages"}


Expand Down
27 changes: 24 additions & 3 deletions python/packages/core/tests/workflow/test_agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,31 @@ def __init__(self, agent_id: str, name: str | None = None) -> None:
self.description: str | None = None

@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def create_session(self, **kwargs: Any) -> AgentSession:
"""Creates a new conversation session for the agent."""
Expand Down
3 changes: 1 addition & 2 deletions python/packages/core/tests/workflow/test_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from typing import Any
from unittest.mock import patch

from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

import pytest
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from agent_framework import (
Executor,
Expand Down
Loading