diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index ee0e813d27..e6eeeb439c 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -2844,6 +2844,7 @@ async def __anext__(self) -> UpdateT: except StopAsyncIteration: self._consumed = True await self._run_cleanup_hooks() + await self.get_final_response() raise except Exception: await self._run_cleanup_hooks() @@ -2895,29 +2896,33 @@ async def get_final_response(self) -> FinalT: await self._get_stream() if self._inner_stream is None: raise RuntimeError("Inner stream not available") - if not self._finalized: + if not self._finalized and not self._consumed: # Consume outer stream (which delegates to inner) if not already consumed - if not self._consumed: - async for _ in self: - pass + async for _ in self: + pass - # First, finalize the inner stream and run its result hooks + # Re-check: __anext__ auto-finalization may have already finalized this stream + if not self._finalized: # This ensures inner post-processing (e.g., context provider notifications) runs - if self._inner_stream._finalizer is not None: - inner_result: Any = self._inner_stream._finalizer(self._inner_stream._updates) - if isinstance(inner_result, Awaitable): - inner_result = await inner_result + # Skip if inner stream was already finalized (e.g., via auto-finalization on iteration) + if not self._inner_stream._finalized: + if self._inner_stream._finalizer is not None: + inner_result: Any = self._inner_stream._finalizer(self._inner_stream._updates) + if isinstance(inner_result, Awaitable): + inner_result = await inner_result + else: + inner_result = self._inner_stream._updates + # Run inner stream's result hooks + for hook in self._inner_stream._result_hooks: + hooked = hook(inner_result) + if isinstance(hooked, Awaitable): + hooked = await hooked + if hooked is not None: + inner_result = hooked + self._inner_stream._final_result = inner_result + self._inner_stream._finalized = True else: - inner_result = self._inner_stream._updates - # Run inner stream's result hooks - for hook in self._inner_stream._result_hooks: - hooked = hook(inner_result) - if isinstance(hooked, Awaitable): - hooked = await hooked - if hooked is not None: - inner_result = hooked - self._inner_stream._final_result = inner_result - self._inner_stream._finalized = True + inner_result = self._inner_stream._final_result # Now finalize the outer stream with its own finalizer # If outer has no finalizer, use inner's result (preserves from_awaitable behavior) @@ -2938,11 +2943,11 @@ async def get_final_response(self) -> FinalT: self._final_result = result self._finalized = True return self._final_result # type: ignore[return-value] + if not self._finalized and not self._consumed: + async for _ in self: + pass + # Re-check: __anext__ auto-finalization may have already finalized this stream if not self._finalized: - if not self._consumed: - async for _ in self: - pass - # Use finalizer if configured, otherwise return collected updates if self._finalizer is not None: result = self._finalizer(self._updates) if isinstance(result, Awaitable): diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index c8d2d9bf8b..e8678c1ff4 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -357,6 +357,40 @@ async def test_chat_client_agent_streaming_session_id_set_without_get_final_resp assert session.service_session_id == "resp_123" +async def test_chat_client_agent_streaming_session_history_saved_without_get_final_response( + chat_client_base: SupportsChatGetResponse, +) -> None: + """Test that session history is saved after streaming iteration without get_final_response(). + + Auto-finalization on iteration completion should trigger after_run providers, + persisting conversation history to the session. + """ + from agent_framework._sessions import InMemoryHistoryProvider + + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[Content.from_text("Hello Alice!")], + role="assistant", + response_id="resp_1", + finish_reason="stop", + ), + ] + ] + + agent = Agent(client=chat_client_base) + session = agent.create_session() + + # Only iterate — do NOT call get_final_response() + async for _ in agent.run("My name is Alice", session=session, stream=True): + pass + + chat_messages: list[Message] = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get("messages", []) + assert len(chat_messages) == 2 + assert chat_messages[0].text == "My name is Alice" + assert chat_messages[1].text == "Hello Alice!" + + async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None: from agent_framework._sessions import InMemoryHistoryProvider diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index bcf3a6891b..5f23caa23a 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -2658,6 +2658,58 @@ async def test_updates_property_returns_collected(self) -> None: assert stream.updates[0].text == "update_0" assert stream.updates[1].text == "update_1" + async def test_auto_finalize_on_iteration_completion(self) -> None: + """Stream auto-finalizes when async iteration completes.""" + stream = ResponseStream(_generate_updates(2), finalizer=_combine_updates) + + async for _ in stream: + pass + + assert stream._finalized is True + assert stream._final_result is not None + assert stream._final_result.text == "update_0update_1" + + async def test_auto_finalize_runs_result_hooks(self) -> None: + """Result hooks run automatically when iteration completes.""" + hook_called = {"value": False} + + def tracking_hook(response: ChatResponse) -> ChatResponse: + hook_called["value"] = True + response.additional_properties["auto_finalized"] = True + return response + + stream = ResponseStream( + _generate_updates(2), + finalizer=_combine_updates, + result_hooks=[tracking_hook], + ) + + async for _ in stream: + pass + + assert hook_called["value"] is True + final = await stream.get_final_response() + assert final.additional_properties["auto_finalized"] is True + + async def test_get_final_response_idempotent_after_auto_finalize(self) -> None: + """get_final_response returns cached result after auto-finalization.""" + call_count = {"value": 0} + + def counting_finalizer(updates: list[ChatResponseUpdate]) -> ChatResponse: + call_count["value"] += 1 + return _combine_updates(updates) + + stream = ResponseStream(_generate_updates(2), finalizer=counting_finalizer) + + async for _ in stream: + pass + + final1 = await stream.get_final_response() + final2 = await stream.get_final_response() + + assert call_count["value"] == 1 + assert final1.text == final2.text + class TestResponseStreamTransformHooks: """Tests for transform hooks (per-update processing).""" diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 788e96e61e..599e62d635 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload import pytest + from agent_framework import ( AgentExecutor, AgentResponse, @@ -59,30 +60,19 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.call_count += 1 if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[ - Content.from_text( - text=f"Response #{self.call_count}: {self.name}" - ) - ] + contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")] ) return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) async def _run() -> AgentResponse: - return AgentResponse( - messages=[ - Message("assistant", [f"Response #{self.call_count}: {self.name}"]) - ] - ) + return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])]) return _run() @@ -120,10 +110,7 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -138,9 +125,9 @@ async def _mark_result_hook_called( self.result_hook_called = True return response - return ResponseStream( - _stream(), finalizer=AgentResponse.from_updates - ).with_result_hook(_mark_result_hook_called) + return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook( + _mark_result_hook_called + ) async def _run() -> AgentResponse: return AgentResponse(messages=[Message("assistant", ["hook test"])]) @@ -148,9 +135,7 @@ async def _run() -> AgentResponse: return _run() -async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> ( - None -): +async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None: """AgentExecutor should call get_final_response() so stream result hooks execute.""" agent = _StreamingHookAgent(id="hook_agent", name="HookAgent") executor = AgentExecutor(agent, id="hook_exec") @@ -217,9 +202,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: executor_state = executor_states[executor.id] # type: ignore[index] assert "cache" in executor_state, "Checkpoint should store executor cache state" - assert "agent_session" in executor_state, ( - "Checkpoint should store executor session state" - ) + assert "agent_session" in executor_state, "Checkpoint should store executor session state" # Verify session state structure session_state = executor_state["agent_session"] # type: ignore[index] @@ -240,15 +223,11 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert restored_agent.call_count == 0 # Build new workflow with the restored executor - wf_resume = SequentialBuilder( - participants=[restored_executor], checkpoint_storage=storage - ).build() + wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build() # Resume from checkpoint resumed_output: AgentExecutorResponse | None = None - async for ev in wf_resume.run( - checkpoint_id=restore_checkpoint.checkpoint_id, stream=True - ): + async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True): if ev.type == "output": resumed_output = ev.data # type: ignore[assignment] if ev.type == "status" and ev.state in ( @@ -391,11 +370,7 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( assert options is not None assert options["additional_function_arguments"]["custom"] == 1 - warned_keys = { - r.message.split("'")[1] - for r in caplog.records - if "reserved" in r.message.lower() - } + warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} assert warned_keys == {"session", "stream", "messages"} diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index 07d1e64c08..633ba1072c 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -16,10 +16,31 @@ def __init__(self, agent_id: str, name: str | None = None) -> None: self.description: str | None = None @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session for the agent.""" diff --git a/python/packages/core/tests/workflow/test_edge.py b/python/packages/core/tests/workflow/test_edge.py index ecaa341726..422d530631 100644 --- a/python/packages/core/tests/workflow/test_edge.py +++ b/python/packages/core/tests/workflow/test_edge.py @@ -4,9 +4,8 @@ from typing import Any from unittest.mock import patch -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - import pytest +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from agent_framework import ( Executor, diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 77827c0634..77777e198b 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -3,6 +3,8 @@ from dataclasses import dataclass import pytest +from typing_extensions import Never + from agent_framework import ( Executor, Message, @@ -14,7 +16,6 @@ handler, response_handler, ) -from typing_extensions import Never # Module-level types for string forward reference tests @@ -155,11 +156,7 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, collector).build() events = await workflow.run("hello world") - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] assert len(invoked_events) == 2 @@ -193,16 +190,10 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: sender = MultiSenderExecutor(id="sender") collector = CollectorExecutor(id="collector") - workflow = ( - WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() - ) + workflow = WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() events = await workflow.run("hello") - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] # Sender should have completed with the sent messages sender_completed = next(e for e in completed_events if e.executor_id == "sender") @@ -210,9 +201,7 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: assert sender_completed.data == ["hello-first", "hello-second"] # Collector should have completed with no sent messages (None) - collector_completed_events = [ - e for e in completed_events if e.executor_id == "collector" - ] + collector_completed_events = [e for e in completed_events if e.executor_id == "collector"] # Collector is called twice (once per message from sender) assert len(collector_completed_events) == 2 for collector_completed in collector_completed_events: @@ -231,11 +220,7 @@ async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: workflow = WorkflowBuilder(start_executor=executor).build() events = await workflow.run("test") - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] assert len(completed_events) == 1 assert completed_events[0].executor_id == "yielder" @@ -263,9 +248,7 @@ class Response: class ProcessorExecutor(Executor): @handler - async def handle( - self, request: Request, ctx: WorkflowContext[Response] - ) -> None: + async def handle(self, request: Request, ctx: WorkflowContext[Response]) -> None: response = Response(results=[request.query.upper()] * request.limit) await ctx.send_message(response) @@ -277,23 +260,13 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: processor = ProcessorExecutor(id="processor") collector = CollectorExecutor(id="collector") - workflow = ( - WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() - ) + workflow = WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() input_request = Request(query="hello", limit=3) events = await workflow.run(input_request) - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] # Check processor invoked event has the Request object processor_invoked = next(e for e in invoked_events if e.executor_id == "processor") @@ -302,9 +275,7 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: assert processor_invoked.data.limit == 3 # Check processor completed event has the Response object - processor_completed = next( - e for e in completed_events if e.executor_id == "processor" - ) + processor_completed = next(e for e in completed_events if e.executor_id == "processor") assert processor_completed.data is not None assert len(processor_completed.data) == 1 assert isinstance(processor_completed.data[0], Response) @@ -390,9 +361,7 @@ async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: # Test executor with union workflow output types class UnionWorkflowOutputExecutor(Executor): @handler - async def handle( - self, text: str, ctx: WorkflowContext[int, str | bool] - ) -> None: + async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: pass executor = UnionWorkflowOutputExecutor(id="union_workflow_output") @@ -403,15 +372,11 @@ async def handle( # Test executor with multiple handlers having different workflow output types class MultiHandlerWorkflowExecutor(Executor): @handler - async def handle_string( - self, text: str, ctx: WorkflowContext[int, str] - ) -> None: + async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None: pass @handler - async def handle_number( - self, num: int, ctx: WorkflowContext[bool, float] - ) -> None: + async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: pass executor = MultiHandlerWorkflowExecutor(id="multi_workflow") @@ -465,9 +430,7 @@ async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: pass @response_handler - async def handle_response( - self, original_request: str, response: bool, ctx: WorkflowContext[float] - ) -> None: + async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None: pass executor = RequestResponseExecutor(id="request_response") @@ -574,9 +537,7 @@ async def test_executor_invoked_event_data_not_mutated_by_handler(): """Test that executor_invoked event (type='executor_invoked').data captures original input, not mutated input.""" @executor(id="Mutator") - async def mutator( - messages: list[Message], ctx: WorkflowContext[list[Message]] - ) -> None: + async def mutator(messages: list[Message], ctx: WorkflowContext[list[Message]]) -> None: # The handler mutates the input list by appending new messages original_len = len(messages) messages.append(Message(role="assistant", text="Added by executor")) @@ -591,11 +552,7 @@ async def mutator( events = await workflow.run(input_messages) # Find the invoked event for the Mutator executor - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] assert len(invoked_events) == 1 mutator_invoked = invoked_events[0] @@ -672,12 +629,8 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess] # Verify can_handle - assert exec_instance.can_handle( - WorkflowMessage(data={"key": "value"}, source_id="mock") - ) - assert not exec_instance.can_handle( - WorkflowMessage(data="string", source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock")) + assert not exec_instance.can_handle(WorkflowMessage(data="string", source_id="mock")) def test_handler_with_explicit_union_input_type(self): """Test that explicit union input_type is handled correctly.""" @@ -698,9 +651,7 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) assert exec_instance.can_handle(WorkflowMessage(data=42, source_id="mock")) # Cannot handle float - assert not exec_instance.can_handle( - WorkflowMessage(data=3.14, source_id="mock") - ) + assert not exec_instance.can_handle(WorkflowMessage(data=3.14, source_id="mock")) def test_handler_with_explicit_union_output_type(self): """Test that explicit union output is normalized to a list.""" @@ -776,9 +727,7 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: class OnlyWorkflowOutputExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler(workflow_output=bool) - async def handle( - self, message: str, ctx: WorkflowContext[int, str] - ) -> None: + async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None: pass def test_handler_explicit_input_type_allows_no_message_annotation(self): @@ -803,9 +752,7 @@ async def handle_explicit(self, message, ctx: WorkflowContext) -> None: # type: pass @handler - async def handle_introspected( - self, message: float, ctx: WorkflowContext[bool] - ) -> None: + async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None: pass exec_instance = MixedExecutor(id="mixed") @@ -831,9 +778,7 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n # Should resolve the string to the actual type assert ForwardRefMessage in exec_instance._handlers # pyright: ignore[reportPrivateUsage] - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock")) def test_handler_with_string_forward_reference_union(self): """Test that string forward references work with union types.""" @@ -846,12 +791,8 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n exec_instance = StringUnionExecutor(id="string_union") # Should handle both types - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock") - ) - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock")) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock")) def test_handler_with_string_forward_reference_output_type(self): """Test that string forward references work for output_type.""" @@ -890,9 +831,7 @@ def test_handler_with_explicit_workflow_output_and_output(self): class PrecedenceExecutor(Executor): @handler(input=int, output=float, workflow_output=str) - async def handle( - self, message: int, ctx: WorkflowContext[int, bool] - ) -> None: + async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: pass exec_instance = PrecedenceExecutor(id="precedence") @@ -958,9 +897,7 @@ class StringUnionWorkflowOutputExecutor(Executor): async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - exec_instance = StringUnionWorkflowOutputExecutor( - id="string_union_workflow_output" - ) + exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output") # Should resolve both types from string union assert ForwardRefTypeA in exec_instance.workflow_output_types @@ -971,14 +908,10 @@ def test_handler_fallback_to_introspection_for_workflow_output_type(self): class IntrospectedWorkflowOutputExecutor(Executor): @handler - async def handle( - self, message: str, ctx: WorkflowContext[int, bool] - ) -> None: + async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None: pass - exec_instance = IntrospectedWorkflowOutputExecutor( - id="introspected_workflow_output" - ) + exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output") # Should use introspected types from WorkflowContext[int, bool] assert int in exec_instance.output_types diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index b5a8bb9902..eacf70c6db 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -717,9 +717,23 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession return AgentSession() @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -813,9 +827,23 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession return AgentSession() @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 0850c6b060..d315f75f85 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -52,9 +52,23 @@ def __init__(self, name: str = "test_agent") -> None: self.captured_kwargs = [] @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -90,9 +104,23 @@ def __init__(self, name: str = "options_agent") -> None: self.captured_kwargs = [] @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -475,9 +503,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -538,9 +580,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -605,9 +661,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index 34c7e8c93f..bf2e277d10 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -38,7 +38,9 @@ async def test_executor_failed_and_workflow_failed_events_streaming(): events.append(ev) # executor_failed event (type='executor_failed') should be emitted before workflow failed event - executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [ + e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed" + ] assert executor_failed_events, "executor_failed event should be emitted when start executor fails" assert executor_failed_events[0].executor_id == "f" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK @@ -96,7 +98,9 @@ async def test_executor_failed_event_from_second_executor_in_chain(): events.append(ev) # executor_failed event should be emitted for the failing executor - executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [ + e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed" + ] assert executor_failed_events, "executor_failed event should be emitted when second executor fails" assert executor_failed_events[0].executor_id == "failing" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK diff --git a/python/samples/01-get-started/README.md b/python/samples/01-get-started/README.md index 5ba119e016..e1bae20b32 100644 --- a/python/samples/01-get-started/README.md +++ b/python/samples/01-get-started/README.md @@ -22,7 +22,7 @@ export AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME="gpt-4o" # optional, defaults to |---|------|-------------------| | 1 | [01_hello_agent.py](01_hello_agent.py) | Create your first agent and run it (streaming and non-streaming). | | 2 | [02_add_tools.py](02_add_tools.py) | Define a function tool with `@tool` and attach it to an agent. | -| 3 | [03_multi_turn.py](03_multi_turn.py) | Keep conversation history across turns with `AgentThread`. | +| 3 | [03_multi_turn.py](03_multi_turn.py) | Keep conversation history across turns with `AgentSession`. | | 4 | [04_memory.py](04_memory.py) | Add dynamic context with a custom `ContextProvider`. | | 5 | [05_first_workflow.py](05_first_workflow.py) | Chain executors into a workflow with edges. | | 6 | [06_host_your_agent.py](06_host_your_agent.py) | Host a single agent with Azure Functions. | diff --git a/python/samples/README.md b/python/samples/README.md index 1f353fbc52..fa091b78bc 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -18,7 +18,7 @@ Start with `01-get-started/` and work through the numbered files: 1. **[01_hello_agent.py](./01-get-started/01_hello_agent.py)** — Create and run your first agent 2. **[02_add_tools.py](./01-get-started/02_add_tools.py)** — Add function tools with `@tool` -3. **[03_multi_turn.py](./01-get-started/03_multi_turn.py)** — Multi-turn conversations with `AgentThread` +3. **[03_multi_turn.py](./01-get-started/03_multi_turn.py)** — Multi-turn conversations with `AgentSession` 4. **[04_memory.py](./01-get-started/04_memory.py)** — Agent memory with `ContextProvider` 5. **[05_first_workflow.py](./01-get-started/05_first_workflow.py)** — Build a workflow with executors and edges 6. **[06_host_your_agent.py](./01-get-started/06_host_your_agent.py)** — Host your agent via Azure Functions