diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index 460c0a6d1a..32faa67527 100644 --- a/python/packages/ag-ui/pyproject.toml +++ b/python/packages/ag-ui/pyproject.toml @@ -44,7 +44,7 @@ packages = ["agent_framework_ag_ui", "agent_framework_ag_ui_examples"] [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests/ag_ui"] -pythonpath = ["."] +pythonpath = [".", "tests/ag_ui"] markers = [ "integration: marks tests as integration tests that require external services", ] diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index d86ebb1720..b73eddb8ad 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -4,6 +4,7 @@ import sys from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence +from pathlib import Path from types import SimpleNamespace from typing import Any, Generic, Literal, cast, overload @@ -36,6 +37,13 @@ ResponseFn = Callable[..., Awaitable[ChatResponse]] +def pytest_configure() -> None: + """Ensure this test directory is on sys.path so helper modules can be imported by name.""" + test_dir = str(Path(__file__).resolve().parent) + if test_dir not in sys.path: + sys.path.insert(0, test_dir) + + class StreamingChatClientStub( ChatMiddlewareLayer[OptionsCoT], FunctionInvocationLayer[OptionsCoT], @@ -241,3 +249,83 @@ def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], Stream def stub_agent() -> type[SupportsAgentRun]: """Return the StubAgent class for creating test instances.""" return StubAgent # type: ignore[return-value] + + +# ── Fixtures for golden / integration tests ── + + +@pytest.fixture +def collect_events() -> Callable[..., Any]: + """Return an async helper that collects all events from an async generator.""" + + async def _collect(async_gen: AsyncIterable[Any]) -> list[Any]: + return [event async for event in async_gen] + + return _collect + + +@pytest.fixture +def make_agent_wrapper() -> Callable[..., Any]: + """Factory that builds an AgentFrameworkAgent from a stream function. + + Usage:: + + agent = make_agent_wrapper( + stream_fn=stream_from_updates(updates), + state_schema=..., + ) + events = [e async for e in agent.run(payload)] + """ + from agent_framework_ag_ui import AgentFrameworkAgent + + def _factory( + stream_fn: StreamFn, + *, + state_schema: Any | None = None, + predict_state_config: dict[str, dict[str, str]] | None = None, + require_confirmation: bool = True, + ) -> Any: + client = StreamingChatClientStub(stream_fn) + stub = StubAgent(client=client) + return AgentFrameworkAgent( + agent=stub, + state_schema=state_schema, + predict_state_config=predict_state_config, + require_confirmation=require_confirmation, + ) + + return _factory + + +@pytest.fixture +def make_app() -> Callable[..., Any]: + """Factory that builds a FastAPI app with an AG-UI endpoint. + + Usage:: + + app = make_app(agent_or_wrapper, path="/test") + """ + from fastapi import FastAPI + + from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint + + def _factory( + agent: Any, + *, + path: str = "/", + state_schema: Any | None = None, + predict_state_config: dict[str, dict[str, str]] | None = None, + default_state: dict[str, Any] | None = None, + ) -> FastAPI: + app = FastAPI() + add_agent_framework_fastapi_endpoint( + app, + agent, + path=path, + state_schema=state_schema, + predict_state_config=predict_state_config, + default_state=default_state, + ) + return app + + return _factory diff --git a/python/packages/ag-ui/tests/ag_ui/event_stream.py b/python/packages/ag-ui/tests/ag_ui/event_stream.py new file mode 100644 index 0000000000..a6300c1042 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/event_stream.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""EventStream assertion helper for AG-UI regression tests.""" + +from __future__ import annotations + +from typing import Any + + +class EventStream: + """Wraps a list of AG-UI events with structured assertion methods. + + Usage: + events = [event async for event in agent.run(payload)] + stream = EventStream(events) + stream.assert_bookends() + stream.assert_text_messages_balanced() + """ + + def __init__(self, events: list[Any]) -> None: + self.events = events + + def __len__(self) -> int: + return len(self.events) + + def __iter__(self): + return iter(self.events) + + def types(self) -> list[str]: + """Return ordered list of event type strings.""" + return [self._type_str(e) for e in self.events] + + def get(self, event_type: str) -> list[Any]: + """Filter events matching the given type string.""" + return [e for e in self.events if self._type_str(e) == event_type] + + def first(self, event_type: str) -> Any: + """Return the first event matching the given type, or raise.""" + matches = self.get(event_type) + if not matches: + raise ValueError(f"No event of type {event_type!r} found. Available: {self.types()}") + return matches[0] + + def last(self, event_type: str) -> Any: + """Return the last event matching the given type, or raise.""" + matches = self.get(event_type) + if not matches: + raise ValueError(f"No event of type {event_type!r} found. Available: {self.types()}") + return matches[-1] + + def snapshot(self) -> dict[str, Any]: + """Return the latest StateSnapshotEvent snapshot dict.""" + return self.last("STATE_SNAPSHOT").snapshot + + def messages_snapshot(self) -> list[Any]: + """Return the latest MessagesSnapshotEvent messages list.""" + return self.last("MESSAGES_SNAPSHOT").messages + + # ── Structural assertions ── + + def assert_bookends(self) -> None: + """Assert first event is RUN_STARTED and last is RUN_FINISHED.""" + types = self.types() + assert types, "Event stream is empty" + assert types[0] == "RUN_STARTED", f"Expected RUN_STARTED first, got {types[0]}" + assert types[-1] == "RUN_FINISHED", f"Expected RUN_FINISHED last, got {types[-1]}" + + def assert_has_run_lifecycle(self) -> None: + """Assert RUN_STARTED is first and RUN_FINISHED exists (may not be last). + + Use this instead of assert_bookends() for workflow resume streams where + _drain_open_message() can emit TEXT_MESSAGE_END after RUN_FINISHED. + """ + types = self.types() + assert types, "Event stream is empty" + assert types[0] == "RUN_STARTED", f"Expected RUN_STARTED first, got {types[0]}" + assert "RUN_FINISHED" in types, f"Expected RUN_FINISHED in stream. Types: {types}" + + def assert_strict_types(self, expected: list[str]) -> None: + """Assert exact type sequence match.""" + actual = self.types() + assert actual == expected, f"Event type mismatch.\nExpected: {expected}\nActual: {actual}" + + def assert_ordered_types(self, expected: list[str]) -> None: + """Assert expected types appear as a subsequence (in order, not necessarily contiguous).""" + actual = self.types() + actual_idx = 0 + for expected_type in expected: + found = False + while actual_idx < len(actual): + if actual[actual_idx] == expected_type: + actual_idx += 1 + found = True + break + actual_idx += 1 + if not found: + raise AssertionError( + f"Expected subsequence type {expected_type!r} not found after index {actual_idx}.\n" + f"Expected subsequence: {expected}\n" + f"Actual types: {actual}" + ) + + def assert_text_messages_balanced(self) -> None: + """Assert every TEXT_MESSAGE_START has a matching TEXT_MESSAGE_END with the same message_id.""" + starts: dict[str, int] = {} + ends: set[str] = set() + for i, event in enumerate(self.events): + t = self._type_str(event) + if t == "TEXT_MESSAGE_START": + mid = event.message_id + assert mid not in starts, f"Duplicate TEXT_MESSAGE_START for message_id={mid}" + starts[mid] = i + elif t == "TEXT_MESSAGE_END": + mid = event.message_id + assert mid in starts, f"TEXT_MESSAGE_END for unknown message_id={mid}" + assert mid not in ends, f"Duplicate TEXT_MESSAGE_END for message_id={mid}" + ends.add(mid) + + unclosed = set(starts.keys()) - ends + assert not unclosed, f"Unclosed text messages: {unclosed}" + + def assert_tool_calls_balanced(self) -> None: + """Assert every TOOL_CALL_START has a matching TOOL_CALL_END with the same tool_call_id.""" + starts: dict[str, int] = {} + ends: set[str] = set() + for i, event in enumerate(self.events): + t = self._type_str(event) + if t == "TOOL_CALL_START": + tid = event.tool_call_id + assert tid not in starts, f"Duplicate TOOL_CALL_START for tool_call_id={tid}" + starts[tid] = i + elif t == "TOOL_CALL_END": + tid = event.tool_call_id + assert tid in starts, f"TOOL_CALL_END for unknown tool_call_id={tid}" + assert tid not in ends, f"Duplicate TOOL_CALL_END for tool_call_id={tid}" + ends.add(tid) + + unclosed = set(starts.keys()) - ends + assert not unclosed, f"Unclosed tool calls: {unclosed}" + + def assert_no_run_error(self) -> None: + """Assert no RUN_ERROR events exist.""" + errors = self.get("RUN_ERROR") + if errors: + messages = [getattr(e, "message", str(e)) for e in errors] + raise AssertionError(f"Found {len(errors)} RUN_ERROR event(s): {messages}") + + def assert_has_type(self, event_type: str) -> None: + """Assert at least one event of the given type exists.""" + assert event_type in self.types(), f"Expected {event_type!r} in stream. Available: {self.types()}" + + def assert_message_ids_consistent(self) -> None: + """Assert TEXT_MESSAGE_CONTENT events reference valid, open message_ids.""" + open_messages: set[str] = set() + for event in self.events: + t = self._type_str(event) + if t == "TEXT_MESSAGE_START": + open_messages.add(event.message_id) + elif t == "TEXT_MESSAGE_END": + open_messages.discard(event.message_id) + elif t == "TEXT_MESSAGE_CONTENT": + mid = event.message_id + assert mid in open_messages, f"TEXT_MESSAGE_CONTENT references message_id={mid} which is not open" + + # ── Internal ── + + @staticmethod + def _type_str(event: Any) -> str: + """Extract event type as a plain string.""" + t = getattr(event, "type", None) + if t is None: + return type(event).__name__ + if isinstance(t, str): + return t + return getattr(t, "value", str(t)) diff --git a/python/packages/ag-ui/tests/ag_ui/golden/__init__.py b/python/packages/ag-ui/tests/ag_ui/golden/__init__.py new file mode 100644 index 0000000000..2a50eae894 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/ag-ui/tests/ag_ui/golden/conftest.py b/python/packages/ag-ui/tests/ag_ui/golden/conftest.py new file mode 100644 index 0000000000..c9470fc198 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/conftest.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Conftest for golden tests — ensures parent test dir is importable.""" + +import sys +from pathlib import Path + + +def pytest_configure() -> None: + """Ensure parent test directory is on sys.path for helper module imports.""" + parent_test_dir = str(Path(__file__).resolve().parent.parent) + if parent_test_dir not in sys.path: + sys.path.insert(0, parent_test_dir) diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py new file mode 100644 index 0000000000..00516171c2 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_agentic_chat.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the basic agentic chat scenario.""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream + +from agent_framework_ag_ui import AgentFrameworkAgent + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent(updates=updates) + return AgentFrameworkAgent(agent=stub, **kwargs) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +BASIC_PAYLOAD: dict[str, Any] = { + "thread_id": "thread-chat", + "run_id": "run-chat", + "messages": [{"role": "user", "content": "Hello"}], +} + + +def _text_update(text: str) -> AgentResponseUpdate: + return AgentResponseUpdate(contents=[Content.from_text(text=text)], role="assistant") + + +def _snapshot_role(msg: Any) -> str: + """Extract role string from a snapshot message (Pydantic model or dict).""" + role = getattr(msg, "role", None) or (msg.get("role") if isinstance(msg, dict) else None) + if role is None: + return "" + return str(getattr(role, "value", role)) + + +def _snapshot_content(msg: Any) -> str: + """Extract content string from a snapshot message.""" + content = getattr(msg, "content", None) or (msg.get("content") if isinstance(msg, dict) else "") + return str(content) if content else "" + + +# ── Golden stream tests ── + + +async def test_basic_chat_golden_event_sequence() -> None: + """Assert the exact event type sequence for a single text response.""" + agent = _build_agent([_text_update("Hi there!")]) + stream = await _run(agent, BASIC_PAYLOAD) + + stream.assert_strict_types( + [ + "RUN_STARTED", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + ) + + +async def test_basic_chat_bookends() -> None: + """RUN_STARTED is first, RUN_FINISHED is last.""" + agent = _build_agent([_text_update("reply")]) + stream = await _run(agent, BASIC_PAYLOAD) + stream.assert_bookends() + + +async def test_basic_chat_text_messages_balanced() -> None: + """Every TEXT_MESSAGE_START has a matching TEXT_MESSAGE_END.""" + agent = _build_agent([_text_update("reply")]) + stream = await _run(agent, BASIC_PAYLOAD) + stream.assert_text_messages_balanced() + + +async def test_basic_chat_no_errors() -> None: + """No RUN_ERROR events in a normal flow.""" + agent = _build_agent([_text_update("reply")]) + stream = await _run(agent, BASIC_PAYLOAD) + stream.assert_no_run_error() + + +async def test_basic_chat_message_id_consistency() -> None: + """All text events reference the same message_id.""" + agent = _build_agent([_text_update("reply")]) + stream = await _run(agent, BASIC_PAYLOAD) + + start = stream.first("TEXT_MESSAGE_START") + content = stream.first("TEXT_MESSAGE_CONTENT") + end = stream.first("TEXT_MESSAGE_END") + assert start.message_id == content.message_id == end.message_id + + +async def test_multi_chunk_text_golden_sequence() -> None: + """Streaming multiple chunks produces START + multiple CONTENT + END.""" + agent = _build_agent([_text_update("Hello "), _text_update("world!")]) + stream = await _run(agent, BASIC_PAYLOAD) + + stream.assert_strict_types( + [ + "RUN_STARTED", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + ) + stream.assert_text_messages_balanced() + stream.assert_message_ids_consistent() + + +async def test_messages_snapshot_contains_assistant_reply() -> None: + """MessagesSnapshotEvent includes the assistant's accumulated text.""" + agent = _build_agent([_text_update("Hello there")]) + stream = await _run(agent, BASIC_PAYLOAD) + + snapshot = stream.messages_snapshot() + assistant_msgs = [m for m in snapshot if _snapshot_role(m) == "assistant"] + assert assistant_msgs, "No assistant message in snapshot" + assert any("Hello there" in _snapshot_content(m) for m in assistant_msgs) + + +async def test_empty_messages_produces_start_and_finish() -> None: + """Empty message list still produces RUN_STARTED and RUN_FINISHED.""" + agent = _build_agent([_text_update("reply")]) + payload = {"thread_id": "t1", "run_id": "r1", "messages": []} + stream = await _run(agent, payload) + + stream.assert_bookends() + assert "TEXT_MESSAGE_START" not in stream.types() diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py new file mode 100644 index 0000000000..7b48740cad --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_backend_tools.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the backend (server-side) tools scenario.""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream + +from agent_framework_ag_ui import AgentFrameworkAgent + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent(updates=updates) + return AgentFrameworkAgent(agent=stub, **kwargs) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-tools", + "run_id": "run-tools", + "messages": [{"role": "user", "content": "What's the weather?"}], +} + + +# ── Golden stream tests ── + + +async def test_tool_call_lifecycle_golden_sequence() -> None: + """Assert the full event sequence for a tool call → result → text response.""" + updates = [ + # LLM calls the tool + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments='{"city": "SF"}')], + role="assistant", + ), + # Tool result comes back + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F and sunny")], + role="assistant", + ), + # LLM responds with text + AgentResponseUpdate( + contents=[Content.from_text(text="It's 72°F and sunny in SF!")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_ordered_types( + [ + "RUN_STARTED", + "TEXT_MESSAGE_START", # Synthetic start for tool-only message + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "TOOL_CALL_RESULT", + "TEXT_MESSAGE_END", # End of synthetic message + "TEXT_MESSAGE_START", # New message for text response + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + ) + + +async def test_tool_calls_balanced() -> None: + """Every TOOL_CALL_START has a matching TOOL_CALL_END.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments='{"city": "SF"}')], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's 72°F!")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_tool_calls_balanced() + + +async def test_text_messages_balanced_with_tools() -> None: + """Text messages are properly balanced even around tool calls.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments='{"city": "SF"}')], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's 72°F!")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_text_messages_balanced() + + +async def test_tool_call_id_matches_result() -> None: + """TOOL_CALL_START and TOOL_CALL_RESULT reference the same tool_call_id.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments="{}")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + start = stream.first("TOOL_CALL_START") + result = stream.first("TOOL_CALL_RESULT") + assert start.tool_call_id == result.tool_call_id == "call-1" + + +async def test_tool_result_content_preserved() -> None: + """TOOL_CALL_RESULT event carries the tool's result content.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments="{}")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F and sunny")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + result = stream.first("TOOL_CALL_RESULT") + assert result.content == "72°F and sunny" + + +async def test_no_run_error_on_tool_flow() -> None: + """Tool call flow doesn't produce RUN_ERROR.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments="{}")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_no_run_error() + stream.assert_bookends() + + +async def test_multiple_sequential_tool_calls() -> None: + """Multiple sequential tool calls each produce balanced START/END pairs.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="tool_a", call_id="call-a", arguments="{}")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-a", result="result-a")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_call(name="tool_b", call_id="call-b", arguments="{}")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-b", result="result-b")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="Done!")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_tool_calls_balanced() + stream.assert_text_messages_balanced() + stream.assert_bookends() + + # Both tool calls should appear + starts = stream.get("TOOL_CALL_START") + assert len(starts) == 2 + assert {s.tool_call_name for s in starts} == {"tool_a", "tool_b"} + + +async def test_messages_snapshot_includes_tool_calls() -> None: + """MessagesSnapshotEvent includes tool call and result messages.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments='{"city":"SF"}')], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's warm!")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_has_type("MESSAGES_SNAPSHOT") + snapshot = stream.messages_snapshot() + # Should have: user message, assistant with tool_calls, tool result, assistant text + assert len(snapshot) >= 3 diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py new file mode 100644 index 0000000000..211bbeedc6 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_agent.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the generative UI (workflow-as-agent) scenario.""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import WorkflowBuilder, WorkflowContext, executor +from event_stream import EventStream +from typing_extensions import Never + +from agent_framework_ag_ui import AgentFrameworkWorkflow + + +async def _run(wrapper: AgentFrameworkWorkflow, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in wrapper.run(payload)]) + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-gen-ui-agent", + "run_id": "run-gen-ui-agent", + "messages": [{"role": "user", "content": "Generate a UI"}], +} + + +# ── Golden stream tests ── + + +async def test_workflow_agent_golden_sequence() -> None: + """Workflow-as-agent: emits step events and text content.""" + + @executor(id="generator") + async def generator(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("Here is your generated UI content!") + + workflow = WorkflowBuilder(start_executor=generator).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, PAYLOAD) + + stream.assert_bookends() + stream.assert_no_run_error() + stream.assert_text_messages_balanced() + + # Should have step events for the executor + stream.assert_has_type("STEP_STARTED") + stream.assert_has_type("STEP_FINISHED") + + # Should have text message content + stream.assert_has_type("TEXT_MESSAGE_CONTENT") + + +async def test_workflow_agent_step_names_match() -> None: + """Step started/finished events reference the executor name.""" + + @executor(id="my_executor") + async def my_executor(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("Done!") + + workflow = WorkflowBuilder(start_executor=my_executor).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, PAYLOAD) + + started = [e for e in stream.get("STEP_STARTED") if getattr(e, "step_name", "") == "my_executor"] + finished = [e for e in stream.get("STEP_FINISHED") if getattr(e, "step_name", "") == "my_executor"] + assert started, "Expected STEP_STARTED for 'my_executor'" + assert finished, "Expected STEP_FINISHED for 'my_executor'" + + +async def test_workflow_agent_ordered_events() -> None: + """Workflow events follow expected ordering: RUN_STARTED → STEP_STARTED → content → STEP_FINISHED → RUN_FINISHED.""" + + @executor(id="my_step") + async def my_step(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("Generated content") + + workflow = WorkflowBuilder(start_executor=my_step).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, PAYLOAD) + + stream.assert_ordered_types( + [ + "RUN_STARTED", + "STEP_STARTED", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "STEP_FINISHED", + "TEXT_MESSAGE_END", + "RUN_FINISHED", + ] + ) diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py new file mode 100644 index 0000000000..b154b53236 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_generative_ui_tool.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the client-side (declaration-only) tools scenario.""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream + +from agent_framework_ag_ui import AgentFrameworkAgent + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent(updates=updates) + return AgentFrameworkAgent(agent=stub, **kwargs) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-gen-ui-tool", + "run_id": "run-gen-ui-tool", + "messages": [{"role": "user", "content": "Show me a chart"}], + "tools": [ + { + "type": "function", + "function": { + "name": "render_chart", + "description": "Render a chart in the UI", + "parameters": { + "type": "object", + "properties": {"data": {"type": "array"}}, + }, + }, + } + ], +} + + +# ── Golden stream tests ── + + +async def test_declaration_only_tool_golden_sequence() -> None: + """Declaration-only tool: TOOL_CALL_START/ARGS emitted, TOOL_CALL_END at stream end.""" + # The LLM calls a client-side tool (no server-side execution) + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="render_chart", + call_id="call-chart", + arguments='{"data": [1, 2, 3]}', + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + stream.assert_no_run_error() + + # Tool call start and args should be present + stream.assert_has_type("TOOL_CALL_START") + stream.assert_has_type("TOOL_CALL_ARGS") + + # TOOL_CALL_END should be emitted (via get_pending_without_end) + stream.assert_has_type("TOOL_CALL_END") + stream.assert_tool_calls_balanced() + + +async def test_declaration_only_tool_no_tool_call_result() -> None: + """Declaration-only tools should NOT produce TOOL_CALL_RESULT events.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="render_chart", + call_id="call-chart", + arguments='{"data": [1, 2, 3]}', + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + assert "TOOL_CALL_RESULT" not in stream.types(), "Declaration-only tools should not have TOOL_CALL_RESULT" + + +async def test_declaration_only_tool_text_messages_balanced() -> None: + """Text messages remain balanced even with declaration-only tools.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="render_chart", + call_id="call-chart", + arguments='{"data": [1, 2, 3]}', + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_text_messages_balanced() + + +async def test_declaration_only_tool_messages_snapshot() -> None: + """MessagesSnapshotEvent includes the tool call for declaration-only tools.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="render_chart", + call_id="call-chart", + arguments='{"data": [1, 2, 3]}', + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_has_type("MESSAGES_SNAPSHOT") diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py new file mode 100644 index 0000000000..7af256f625 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_hitl.py @@ -0,0 +1,196 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the HITL (human-in-the-loop) approval scenario.""" + +from __future__ import annotations + +import json +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream + +from agent_framework_ag_ui import AgentFrameworkAgent + +PREDICT_CONFIG = { + "tasks": { + "tool": "generate_task_steps", + "tool_argument": "steps", + } +} + +STATE_SCHEMA = { + "tasks": {"type": "array", "items": {"type": "object"}}, +} + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent(updates=updates) + return AgentFrameworkAgent( + agent=stub, + state_schema=STATE_SCHEMA, + predict_state_config=PREDICT_CONFIG, + require_confirmation=True, + **kwargs, + ) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +STEPS = [ + {"description": "Step 1: Plan", "status": "enabled"}, + {"description": "Step 2: Execute", "status": "enabled"}, +] + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-hitl", + "run_id": "run-hitl", + "messages": [{"role": "user", "content": "Plan my tasks"}], + "state": {"tasks": []}, +} + + +# ── Turn 1: Tool call → confirm_changes → interrupt ── + + +async def test_hitl_turn1_golden_sequence() -> None: + """Turn 1 emits tool call, confirm_changes, and finishes with interrupt.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="generate_task_steps", + call_id="call-steps", + arguments=json.dumps({"steps": STEPS}), + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + # Should have: tool call start/args/end for the primary tool, + # then TOOL_CALL_END, STATE_SNAPSHOT, confirm_changes cycle + stream.assert_bookends() + stream.assert_no_run_error() + + # confirm_changes tool call should be present + tool_starts = stream.get("TOOL_CALL_START") + tool_names = [getattr(s, "tool_call_name", None) for s in tool_starts] + assert "generate_task_steps" in tool_names + assert "confirm_changes" in tool_names + + # RUN_FINISHED should have interrupt metadata + finished = stream.last("RUN_FINISHED") + interrupt = getattr(finished, "interrupt", None) + assert interrupt is not None, "Expected interrupt in RUN_FINISHED" + assert len(interrupt) > 0 + + +async def test_hitl_turn1_tool_calls_balanced() -> None: + """All tool calls in turn 1 (primary + confirm_changes) are balanced.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="generate_task_steps", + call_id="call-steps", + arguments=json.dumps({"steps": STEPS}), + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_tool_calls_balanced() + + +async def test_hitl_turn1_text_messages_balanced() -> None: + """Text messages are balanced even in the approval flow.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="generate_task_steps", + call_id="call-steps", + arguments=json.dumps({"steps": STEPS}), + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_text_messages_balanced() + + +# ── Turn 2: Resume with approval → confirmation message → no interrupt ── + + +async def test_hitl_turn2_resume_with_approval() -> None: + """Resuming with confirm_changes result emits confirmation text and finishes cleanly.""" + # Turn 2: user sends confirm_changes result as resume + # The agent wrapper sees a confirm_changes response and emits a confirmation message + confirm_result = json.dumps( + { + "accepted": True, + "steps": STEPS, + } + ) + + # Build payload with resume containing the approval + # For confirm_changes, the messages should include the tool result + payload: dict[str, Any] = { + "thread_id": "thread-hitl", + "run_id": "run-hitl-2", + "messages": [ + {"role": "user", "content": "Plan my tasks"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "confirm-id-1", + "type": "function", + "function": {"name": "confirm_changes", "arguments": json.dumps({"steps": STEPS})}, + } + ], + }, + { + "role": "tool", + "toolCallId": "confirm-id-1", + "content": confirm_result, + }, + ], + "state": {"tasks": []}, + } + + # In turn 2, the agent sees the confirm_changes result and emits a confirmation text + updates = [ + AgentResponseUpdate( + contents=[Content.from_text(text="Tasks confirmed!")], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, payload) + + stream.assert_bookends() + stream.assert_text_messages_balanced() + stream.assert_no_run_error() + + # Should have text message content (the confirmation message) + text_events = stream.get("TEXT_MESSAGE_CONTENT") + assert text_events, "Expected confirmation text message" + + # RUN_FINISHED should NOT have interrupt (approval completed) + finished = stream.last("RUN_FINISHED") + interrupt = getattr(finished, "interrupt", None) + assert not interrupt, f"Expected no interrupt after approval, got {interrupt}" diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py new file mode 100644 index 0000000000..3870e00728 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_predictive_state.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the predictive state scenario.""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream + +from agent_framework_ag_ui import AgentFrameworkAgent + +PREDICT_CONFIG = { + "document": { + "tool": "update_document", + "tool_argument": "content", + } +} + +STATE_SCHEMA = { + "document": {"type": "string"}, +} + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent(updates=updates) + return AgentFrameworkAgent( + agent=stub, + state_schema=STATE_SCHEMA, + predict_state_config=PREDICT_CONFIG, + require_confirmation=False, + **kwargs, + ) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-predict", + "run_id": "run-predict", + "messages": [{"role": "user", "content": "Write a document"}], + "state": {"document": ""}, +} + + +# ── Golden stream tests ── + + +async def test_predictive_state_emits_deltas_during_tool_args() -> None: + """STATE_DELTA events are emitted as tool arguments stream in.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="update_document", call_id="call-1", arguments="")], + role="assistant", + ), + AgentResponseUpdate( + contents=[ + Content.from_function_call(name="update_document", call_id="call-1", arguments='{"content": "Hello') + ], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_call(name="update_document", call_id="call-1", arguments=' world"}')], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + stream.assert_no_run_error() + + # PredictState custom event should be present + custom_events = stream.get("CUSTOM") + predict_events = [e for e in custom_events if getattr(e, "name", None) == "PredictState"] + assert predict_events, "Expected PredictState custom event" + + # STATE_DELTA events should be emitted during tool arg streaming + assert "STATE_DELTA" in stream.types(), "Expected STATE_DELTA events during predictive streaming" + + +async def test_predictive_state_snapshot_after_tool_end() -> None: + """STATE_SNAPSHOT is emitted when a predictive tool completes (no confirmation).""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="update_document", call_id="call-1", arguments='{"content": "Final text"}' + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + + # Should have initial state snapshot + updated snapshot after tool completion + snapshots = stream.get("STATE_SNAPSHOT") + assert len(snapshots) >= 1, "Expected at least one STATE_SNAPSHOT" + + +async def test_predictive_state_ordered_events() -> None: + """Event ordering: RUN_STARTED → PredictState → STATE_SNAPSHOT → TOOL_CALL_* → STATE_SNAPSHOT → RUN_FINISHED.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_function_call(name="update_document", call_id="call-1", arguments='{"content": "doc"}') + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_ordered_types( + [ + "RUN_STARTED", + "CUSTOM", # PredictState + "STATE_SNAPSHOT", # Initial state + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "RUN_FINISHED", + ] + ) diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py new file mode 100644 index 0000000000..efbe34ed8f --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_shared_state.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the shared state (structured output) scenario.""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from event_stream import EventStream +from pydantic import BaseModel + +from agent_framework_ag_ui import AgentFrameworkAgent + + +class RecipeState(BaseModel): + recipe_title: str = "" + ingredients: list[str] = [] + message: str = "" + + +def _build_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> AgentFrameworkAgent: + stub = StubAgent( + updates=updates, + default_options={"tools": None, "response_format": RecipeState}, + ) + return AgentFrameworkAgent( + agent=stub, + state_schema={ + "recipe_title": {"type": "string"}, + "ingredients": {"type": "array", "items": {"type": "string"}}, + }, + **kwargs, + ) + + +async def _run(agent: AgentFrameworkAgent, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +PAYLOAD: dict[str, Any] = { + "thread_id": "thread-state", + "run_id": "run-state", + "messages": [{"role": "user", "content": "Give me a pasta recipe"}], + "state": {"recipe_title": "", "ingredients": []}, +} + + +# ── Golden stream tests ── + + +async def test_shared_state_emits_state_snapshot() -> None: + """Structured output agent emits STATE_SNAPSHOT with parsed model fields.""" + # The structured output agent gets a response that the framework parses as RecipeState + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_text( + text='{"recipe_title": "Pasta Carbonara", "ingredients": ["pasta", "eggs", "cheese"], "message": "Here is your recipe!"}' + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + stream.assert_bookends() + stream.assert_no_run_error() + + # Should have STATE_SNAPSHOT with the initial state at minimum + stream.assert_has_type("STATE_SNAPSHOT") + + +async def test_shared_state_initial_snapshot_on_first_update() -> None: + """When state_schema and state are provided, initial STATE_SNAPSHOT is emitted after RUN_STARTED.""" + updates = [ + AgentResponseUpdate( + contents=[Content.from_text(text='{"recipe_title": "Test", "ingredients": [], "message": "hi"}')], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + # RUN_STARTED should be followed by STATE_SNAPSHOT (initial state) + stream.assert_ordered_types(["RUN_STARTED", "STATE_SNAPSHOT"]) + + +async def test_shared_state_text_emitted_from_message_field() -> None: + """Structured output's 'message' field is emitted as text message events.""" + updates = [ + AgentResponseUpdate( + contents=[ + Content.from_text( + text='{"recipe_title": "Pasta", "ingredients": ["pasta"], "message": "Enjoy your pasta!"}' + ) + ], + role="assistant", + ), + ] + agent = _build_agent(updates) + stream = await _run(agent, PAYLOAD) + + # Text should be emitted from the message field + text_contents = stream.get("TEXT_MESSAGE_CONTENT") + if text_contents: + combined = "".join(getattr(e, "delta", "") for e in text_contents) + assert "Enjoy your pasta!" in combined diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py new file mode 100644 index 0000000000..61e89057fb --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_subgraphs.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Golden event-stream tests for the workflow HITL (subgraphs) scenario. + +Extends the existing test_subgraphs_example_agent.py with EventStream assertions +on full event ordering, balancing, and interrupt structure. +""" + +from __future__ import annotations + +import json +from typing import Any + +from event_stream import EventStream + +from agent_framework_ag_ui_examples.agents.subgraphs_agent import subgraphs_agent + + +async def _run(agent: Any, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in agent.run(payload)]) + + +# ── Turn 1: Initial request → flight interrupt ── + + +async def test_subgraphs_turn1_golden_bookends() -> None: + """Turn 1 starts with RUN_STARTED and ends with RUN_FINISHED.""" + agent = subgraphs_agent() + stream = await _run( + agent, + { + "thread_id": "thread-sub-golden-1", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip to San Francisco"}], + }, + ) + stream.assert_bookends() + + +async def test_subgraphs_turn1_no_errors() -> None: + """Turn 1 completes without errors.""" + agent = subgraphs_agent() + stream = await _run( + agent, + { + "thread_id": "thread-sub-golden-2", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip"}], + }, + ) + stream.assert_no_run_error() + + +async def test_subgraphs_turn1_has_step_events() -> None: + """Turn 1 emits STEP_STARTED and STEP_FINISHED for workflow executors.""" + agent = subgraphs_agent() + stream = await _run( + agent, + { + "thread_id": "thread-sub-golden-3", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip"}], + }, + ) + stream.assert_has_type("STEP_STARTED") + stream.assert_has_type("STEP_FINISHED") + + +async def test_subgraphs_turn1_interrupt_structure() -> None: + """Turn 1 RUN_FINISHED carries flight interrupt with correct structure.""" + agent = subgraphs_agent() + stream = await _run( + agent, + { + "thread_id": "thread-sub-golden-4", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip to SF"}], + }, + ) + + finished = stream.last("RUN_FINISHED") + interrupt = getattr(finished, "interrupt", None) + assert interrupt is not None, "Expected interrupt in RUN_FINISHED" + assert isinstance(interrupt, list) + assert len(interrupt) > 0 + assert interrupt[0]["value"]["agent"] == "flights" + assert len(interrupt[0]["value"]["options"]) == 2 + + +async def test_subgraphs_turn1_text_messages_balanced() -> None: + """All text messages in turn 1 are properly balanced.""" + agent = subgraphs_agent() + stream = await _run( + agent, + { + "thread_id": "thread-sub-golden-5", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip"}], + }, + ) + stream.assert_text_messages_balanced() + + +async def test_subgraphs_turn1_ordered_flow() -> None: + """Turn 1 event ordering: RUN_STARTED → STATE_SNAPSHOT → STEP_* → TOOL_CALL_* → RUN_FINISHED.""" + agent = subgraphs_agent() + stream = await _run( + agent, + { + "thread_id": "thread-sub-golden-6", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip"}], + }, + ) + stream.assert_ordered_types( + [ + "RUN_STARTED", + "STATE_SNAPSHOT", + "STEP_STARTED", + "RUN_FINISHED", + ] + ) + + +# ── Multi-turn: Flight selection → hotel interrupt → completion ── + + +async def test_subgraphs_full_flow_event_ordering() -> None: + """Complete 3-turn flow maintains proper event ordering throughout.""" + agent = subgraphs_agent() + thread_id = "thread-sub-golden-full" + + # Turn 1 + stream1 = await _run( + agent, + { + "thread_id": thread_id, + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan a trip to SF from Amsterdam"}], + }, + ) + stream1.assert_bookends() + stream1.assert_no_run_error() + + # Extract flight interrupt + finished1 = stream1.last("RUN_FINISHED") + interrupt1 = finished1.model_dump()["interrupt"][0] + + # Turn 2: Select flight + stream2 = await _run( + agent, + { + "thread_id": thread_id, + "run_id": "run-2", + "resume": { + "interrupts": [ + { + "id": interrupt1["id"], + "value": json.dumps( + { + "airline": "United", + "departure": "Amsterdam (AMS)", + "arrival": "San Francisco (SFO)", + "price": "$720", + "duration": "12h 15m", + } + ), + } + ] + }, + }, + ) + stream2.assert_bookends() + stream2.assert_no_run_error() + + # Should now have hotel interrupt + finished2 = stream2.last("RUN_FINISHED") + interrupt2 = finished2.model_dump()["interrupt"] + assert interrupt2[0]["value"]["agent"] == "hotels" + + # Turn 3: Select hotel + stream3 = await _run( + agent, + { + "thread_id": thread_id, + "run_id": "run-3", + "resume": { + "interrupts": [ + { + "id": interrupt2[0]["id"], + "value": json.dumps( + { + "name": "The Ritz-Carlton", + "location": "Nob Hill", + "price_per_night": "$550/night", + "rating": "4.8 stars", + } + ), + } + ] + }, + }, + ) + stream3.assert_bookends() + stream3.assert_no_run_error() + stream3.assert_text_messages_balanced() + + # Final turn should not have interrupt + finished3 = stream3.last("RUN_FINISHED") + final_interrupt = getattr(finished3, "interrupt", None) + assert not final_interrupt, f"Expected no interrupt after completion, got {final_interrupt}" diff --git a/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py new file mode 100644 index 0000000000..5f13b8e67f --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/golden/test_scenario_workflow.py @@ -0,0 +1,962 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Comprehensive golden event-stream tests for AgentFrameworkWorkflow. + +Covers the full matrix of workflow-specific AG-UI patterns: +- request_info → TOOL_CALL lifecycle and balancing +- Executor step events and activity snapshots +- Text output, dict output, BaseEvent passthrough, AgentResponse output +- Text deduplication across workflow outputs +- Workflow error handling → RUN_ERROR +- Multi-turn interrupt/resume round-trips +- Empty turns with pending requests +- Custom workflow events +- Text message draining on request_info and executor boundaries +""" + +import json +from typing import Any, cast + +from ag_ui.core import EventType, StateSnapshotEvent +from agent_framework import ( + AgentResponse, + Content, + Executor, + Message, + WorkflowBuilder, + WorkflowContext, + WorkflowEvent, + executor, + handler, + response_handler, +) +from event_stream import EventStream +from typing_extensions import Never + +from agent_framework_ag_ui import AgentFrameworkWorkflow + + +async def _run(wrapper: AgentFrameworkWorkflow, payload: dict[str, Any]) -> EventStream: + return EventStream([event async for event in wrapper.run(payload)]) + + +def _payload( + msg: str = "go", + *, + thread_id: str = "thread-wf", + run_id: str = "run-wf", + **extra: Any, +) -> dict[str, Any]: + return {"thread_id": thread_id, "run_id": run_id, "messages": [{"role": "user", "content": msg}], **extra} + + +# ────────────────────────────────────────────────────────────────────── +# 1. Basic workflow text output +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_text_output_golden_sequence() -> None: + """Simple text output: RUN_STARTED → STEP_STARTED → TEXT_* → STEP_FINISHED → TEXT_MESSAGE_END → RUN_FINISHED.""" + + @executor(id="greeter") + async def greeter(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("Hello from workflow!") + + workflow = WorkflowBuilder(start_executor=greeter).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + stream.assert_no_run_error() + stream.assert_text_messages_balanced() + stream.assert_has_type("TEXT_MESSAGE_START") + stream.assert_has_type("TEXT_MESSAGE_CONTENT") + stream.assert_has_type("TEXT_MESSAGE_END") + + # Verify actual content + deltas = [e.delta for e in stream.get("TEXT_MESSAGE_CONTENT")] + assert "Hello from workflow!" in deltas + + +async def test_workflow_text_output_message_id_consistency() -> None: + """All text events for a single output share the same message_id.""" + + @executor(id="echo") + async def echo(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("echo reply") + + workflow = WorkflowBuilder(start_executor=echo).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_message_ids_consistent() + + +# ────────────────────────────────────────────────────────────────────── +# 2. Executor step events and activity snapshots +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_executor_lifecycle_events() -> None: + """Executor invocation produces STEP_STARTED, ACTIVITY_SNAPSHOT, STEP_FINISHED.""" + + @executor(id="worker") + async def worker(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("done") + + workflow = WorkflowBuilder(start_executor=worker).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + # Step events with executor ID + started = [e for e in stream.get("STEP_STARTED") if getattr(e, "step_name", "") == "worker"] + finished = [e for e in stream.get("STEP_FINISHED") if getattr(e, "step_name", "") == "worker"] + assert started, "Expected STEP_STARTED for 'worker'" + assert finished, "Expected STEP_FINISHED for 'worker'" + + # Activity snapshots + activities = stream.get("ACTIVITY_SNAPSHOT") + assert activities, "Expected ACTIVITY_SNAPSHOT events" + # Check one of them has executor payload + executor_activities = [a for a in activities if getattr(a, "activity_type", None) == "executor"] + assert executor_activities, "Expected executor-type activity snapshots" + + +async def test_workflow_executor_step_ordering() -> None: + """STEP_STARTED comes before content, STEP_FINISHED comes after.""" + + @executor(id="orderer") + async def orderer(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("ordered output") + + workflow = WorkflowBuilder(start_executor=orderer).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_ordered_types( + [ + "RUN_STARTED", + "STEP_STARTED", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "STEP_FINISHED", + "RUN_FINISHED", + ] + ) + + +# ────────────────────────────────────────────────────────────────────── +# 3. Dict output → CUSTOM workflow_output +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_dict_output_maps_to_custom_event() -> None: + """Non-chat dict output is emitted as CUSTOM workflow_output event.""" + + @executor(id="structured") + async def structured(message: Any, ctx: WorkflowContext[Never, dict[str, int]]) -> None: + await ctx.yield_output({"count": 42, "status": 1}) + + workflow = WorkflowBuilder(start_executor=structured).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + stream.assert_no_run_error() + + customs = [e for e in stream.get("CUSTOM") if getattr(e, "name", None) == "workflow_output"] + assert len(customs) == 1 + assert customs[0].value == {"count": 42, "status": 1} + + # Should NOT have TEXT_MESSAGE events for dict output + assert "TEXT_MESSAGE_CONTENT" not in stream.types() + + +# ────────────────────────────────────────────────────────────────────── +# 4. BaseEvent passthrough +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_base_event_passthrough() -> None: + """AG-UI BaseEvent outputs are yielded directly, not wrapped.""" + + @executor(id="stateful") + async def stateful(message: Any, ctx: WorkflowContext[Never, StateSnapshotEvent]) -> None: + await ctx.yield_output(StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot={"active_agent": "flights"})) + + workflow = WorkflowBuilder(start_executor=stateful).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + snapshots = stream.get("STATE_SNAPSHOT") + assert len(snapshots) == 1 + assert snapshots[0].snapshot["active_agent"] == "flights" + + +# ────────────────────────────────────────────────────────────────────── +# 5. AgentResponse output (conversation payload) +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_agent_response_output_extracts_latest_assistant() -> None: + """AgentResponse output uses only the latest assistant message, not full history.""" + + @executor(id="responder") + async def responder(message: Any, ctx: WorkflowContext[Never, AgentResponse]) -> None: + response = AgentResponse( + messages=[ + Message(role="user", contents=[Content.from_text("My order is damaged")]), + Message(role="assistant", contents=[Content.from_text("I'll process your replacement.")]), + ] + ) + await ctx.yield_output(response) + + workflow = WorkflowBuilder(start_executor=responder).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + stream.assert_text_messages_balanced() + + deltas = [e.delta for e in stream.get("TEXT_MESSAGE_CONTENT")] + assert deltas == ["I'll process your replacement."] + + +# ────────────────────────────────────────────────────────────────────── +# 6. Custom workflow events +# ────────────────────────────────────────────────────────────────────── + + +class ProgressEvent(WorkflowEvent): + """Custom workflow event for testing CUSTOM event mapping.""" + + def __init__(self, progress: int) -> None: + super().__init__("custom_progress", data={"progress": progress}) + + +async def test_workflow_custom_events() -> None: + """Custom workflow events are mapped to CUSTOM AG-UI events.""" + + @executor(id="progress_tracker") + async def progress_tracker(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.add_event(ProgressEvent(25)) + await ctx.yield_output("In progress...") + await ctx.add_event(ProgressEvent(100)) + + workflow = WorkflowBuilder(start_executor=progress_tracker).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + stream.assert_no_run_error() + + progress_events = [e for e in stream.get("CUSTOM") if getattr(e, "name", None) == "custom_progress"] + assert len(progress_events) == 2 + assert progress_events[0].value == {"progress": 25} + assert progress_events[1].value == {"progress": 100} + + +# ────────────────────────────────────────────────────────────────────── +# 7. request_info → TOOL_CALL lifecycle +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_request_info_tool_call_lifecycle() -> None: + """request_info emits TOOL_CALL_START/ARGS/END cycle plus CUSTOM request_info.""" + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info("Need approval", str, request_id="req-1") + + workflow = WorkflowBuilder(start_executor=requester).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + stream.assert_no_run_error() + + # Tool call lifecycle + stream.assert_ordered_types( + [ + "RUN_STARTED", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "CUSTOM", # request_info + "RUN_FINISHED", + ] + ) + + # Verify tool call details + start = stream.first("TOOL_CALL_START") + assert start.tool_call_id == "req-1" + assert start.tool_call_name == "request_info" + + # TOOL_CALL_ARGS should contain the request payload + args = stream.first("TOOL_CALL_ARGS") + assert args.tool_call_id == "req-1" + parsed_args = json.loads(args.delta) + assert parsed_args["request_id"] == "req-1" + + # Tool calls should be balanced + stream.assert_tool_calls_balanced() + + +async def test_workflow_request_info_interrupt_in_run_finished() -> None: + """request_info populates RUN_FINISHED.interrupt with the request metadata.""" + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info( + {"message": "Choose a flight", "options": [{"airline": "KLM"}], "agent": "flights"}, + dict, + request_id="flights-choice", + ) + + workflow = WorkflowBuilder(start_executor=requester).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + finished = stream.last("RUN_FINISHED") + interrupt = finished.model_dump().get("interrupt") + assert isinstance(interrupt, list) + assert len(interrupt) == 1 + assert interrupt[0]["id"] == "flights-choice" + assert interrupt[0]["value"]["agent"] == "flights" + + +async def test_workflow_request_info_emits_interrupt_card_event() -> None: + """request_info with dict data emits a WorkflowInterruptEvent custom event.""" + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info( + {"message": "Pick one", "options": ["A", "B"]}, + dict, + request_id="pick-1", + ) + + workflow = WorkflowBuilder(start_executor=requester).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + interrupt_cards = [e for e in stream.get("CUSTOM") if getattr(e, "name", None) == "WorkflowInterruptEvent"] + assert interrupt_cards, "Expected WorkflowInterruptEvent custom event" + + +# ────────────────────────────────────────────────────────────────────── +# 8. Text message draining on request_info boundary +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_text_drained_before_request_info() -> None: + """Open text message is closed (TEXT_MESSAGE_END) before request_info tool calls begin.""" + + @executor(id="text_then_request") + async def text_then_request(message: Any, ctx: WorkflowContext) -> None: + await ctx.yield_output("Please confirm this action.") + await ctx.request_info("Need approval", str, request_id="approval-1") + + workflow = WorkflowBuilder(start_executor=text_then_request).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_text_messages_balanced() + stream.assert_tool_calls_balanced() + + # TEXT_MESSAGE_END must appear before TOOL_CALL_START + types = stream.types() + text_end_idx = types.index("TEXT_MESSAGE_END") + tool_start_idx = types.index("TOOL_CALL_START") + assert text_end_idx < tool_start_idx, ( + f"TEXT_MESSAGE_END (idx={text_end_idx}) must come before TOOL_CALL_START (idx={tool_start_idx})" + ) + + +# ────────────────────────────────────────────────────────────────────── +# 9. Text deduplication +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_skips_duplicate_text_from_snapshot() -> None: + """Duplicate text from AgentResponse snapshot is not re-emitted.""" + + @executor(id="deduper") + async def deduper(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + text = "Order processed successfully." + await ctx.yield_output(text) + # Snapshot repeats the same text + await ctx.yield_output( + AgentResponse( + messages=[ + Message(role="user", contents=[Content.from_text("process order")]), + Message(role="assistant", contents=[Content.from_text(text)]), + ] + ) + ) + + workflow = WorkflowBuilder(start_executor=deduper).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_text_messages_balanced() + deltas = [e.delta for e in stream.get("TEXT_MESSAGE_CONTENT")] + # Text should appear only once + assert deltas == ["Order processed successfully."] + + +async def test_workflow_skips_consecutive_duplicate_outputs() -> None: + """Consecutive identical text outputs are deduplicated.""" + + @executor(id="repeater") + async def repeater(message: Any, ctx: WorkflowContext[Never, Any]) -> None: + text = "Done!" + await ctx.yield_output(text) + await ctx.yield_output(text) + + workflow = WorkflowBuilder(start_executor=repeater).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_text_messages_balanced() + deltas = [e.delta for e in stream.get("TEXT_MESSAGE_CONTENT")] + assert deltas == ["Done!"] + + +async def test_workflow_emits_distinct_consecutive_outputs() -> None: + """Distinct text outputs are all emitted, not incorrectly deduplicated.""" + + @executor(id="multisayer") + async def multisayer(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("First part. ") + await ctx.yield_output("Second part.") + + workflow = WorkflowBuilder(start_executor=multisayer).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_text_messages_balanced() + deltas = [e.delta for e in stream.get("TEXT_MESSAGE_CONTENT")] + assert deltas == ["First part. ", "Second part."] + + +# ────────────────────────────────────────────────────────────────────── +# 10. Workflow error handling → RUN_ERROR +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_error_emits_run_error_event() -> None: + """Exceptions during workflow streaming produce RUN_ERROR events.""" + + class FailingWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + raise RuntimeError("workflow exploded") + yield # pragma: no cover + + return _stream() + + wrapper = AgentFrameworkWorkflow(workflow=cast(Any, FailingWorkflow())) + stream = await _run(wrapper, _payload()) + + # Should still have RUN_STARTED + stream.assert_has_type("RUN_STARTED") + # Should have RUN_ERROR + stream.assert_has_type("RUN_ERROR") + error = stream.first("RUN_ERROR") + assert "workflow exploded" in error.message + + +async def test_workflow_error_preserves_bookend_structure() -> None: + """Even on error, RUN_STARTED is the first event.""" + + class FailingWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + raise ValueError("bad input") + yield # pragma: no cover + + return _stream() + + wrapper = AgentFrameworkWorkflow(workflow=cast(Any, FailingWorkflow())) + stream = await _run(wrapper, _payload()) + + types = stream.types() + assert types[0] == "RUN_STARTED" + assert "RUN_ERROR" in types + + +# ────────────────────────────────────────────────────────────────────── +# 11. Multi-turn request_info interrupt/resume +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_interrupt_resume_round_trip() -> None: + """Turn 1: request_info → interrupt. Turn 2: resume → completion.""" + + class RequesterExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="requester") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info("Choose an option", str, request_id="choice-1") + + @response_handler + async def handle_choice(self, original: str, response: str, ctx: WorkflowContext) -> None: + await ctx.yield_output(f"You chose: {response}") + + workflow = WorkflowBuilder(start_executor=RequesterExecutor()).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1 + stream1 = await _run(wrapper, _payload(thread_id="thread-resume", run_id="run-1")) + stream1.assert_bookends() + stream1.assert_no_run_error() + stream1.assert_tool_calls_balanced() + + finished1 = stream1.last("RUN_FINISHED") + interrupt1 = finished1.model_dump().get("interrupt") + assert interrupt1, "Expected interrupt" + assert interrupt1[0]["id"] == "choice-1" + + # Turn 2: resume + stream2 = await _run( + wrapper, + { + "thread_id": "thread-resume", + "run_id": "run-2", + "messages": [], + "resume": {"interrupts": [{"id": "choice-1", "value": "Option A"}]}, + }, + ) + stream2.assert_has_run_lifecycle() + stream2.assert_no_run_error() + stream2.assert_text_messages_balanced() + + # Should have the response text + deltas = [e.delta for e in stream2.get("TEXT_MESSAGE_CONTENT")] + assert any("Option A" in d for d in deltas), f"Expected 'Option A' in deltas: {deltas}" + + # No interrupt after resume + finished2 = stream2.last("RUN_FINISHED") + interrupt2 = finished2.model_dump().get("interrupt") + assert not interrupt2 + + +async def test_workflow_forwarded_props_resume() -> None: + """CopilotKit-style forwarded_props.command.resume should resume a pending request.""" + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info({"options": [{"name": "A"}]}, dict, request_id="pick") + + workflow = WorkflowBuilder(start_executor=requester).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1 + await _run(wrapper, _payload(thread_id="thread-fwd", run_id="run-1")) + + # Turn 2 via forwarded_props + stream2 = await _run( + wrapper, + { + "thread_id": "thread-fwd", + "run_id": "run-2", + "messages": [], + "forwarded_props": {"command": {"resume": json.dumps({"name": "A"})}}, + }, + ) + stream2.assert_bookends() + stream2.assert_no_run_error() + + finished = stream2.last("RUN_FINISHED") + assert not finished.model_dump().get("interrupt") + + +# ────────────────────────────────────────────────────────────────────── +# 12. Empty turns with pending requests +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_empty_turn_preserves_interrupts() -> None: + """An empty turn with a pending request still returns the interrupt without errors.""" + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info({"prompt": "choose"}, dict, request_id="pick-one") + + workflow = WorkflowBuilder(start_executor=requester).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1: trigger the request + await _run(wrapper, _payload(thread_id="thread-empty", run_id="run-1")) + + # Turn 2: empty messages, no resume + stream2 = await _run( + wrapper, + { + "thread_id": "thread-empty", + "run_id": "run-2", + "messages": [], + }, + ) + stream2.assert_bookends() + stream2.assert_no_run_error() + stream2.assert_tool_calls_balanced() + + # Should re-emit the pending interrupt + finished = stream2.last("RUN_FINISHED") + interrupts = finished.model_dump().get("interrupt") + assert isinstance(interrupts, list) + assert interrupts[0]["id"] == "pick-one" + + # Should have TOOL_CALL events for the pending request + stream2.assert_has_type("TOOL_CALL_START") + + +async def test_workflow_empty_turn_no_pending_requests() -> None: + """Empty turn with no pending requests produces clean bookends.""" + + @executor(id="noop") + async def noop(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("done") + + workflow = WorkflowBuilder(start_executor=noop).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Run once to completion + await _run(wrapper, _payload(thread_id="thread-empty-clean", run_id="run-1")) + + # Empty turn + stream2 = await _run( + wrapper, + { + "thread_id": "thread-empty-clean", + "run_id": "run-2", + "messages": [], + }, + ) + stream2.assert_bookends() + stream2.assert_no_run_error() + + +# ────────────────────────────────────────────────────────────────────── +# 13. Usage content as CUSTOM event +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_usage_output_maps_to_custom_event() -> None: + """Usage Content outputs are surfaced as custom usage events.""" + + @executor(id="usage_reporter") + async def usage_reporter(message: Any, ctx: WorkflowContext[Never, Content]) -> None: + await ctx.yield_output( + Content.from_usage({"input_token_count": 100, "output_token_count": 50, "total_token_count": 150}) + ) + + workflow = WorkflowBuilder(start_executor=usage_reporter).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + stream = await _run(wrapper, _payload()) + + stream.assert_bookends() + stream.assert_no_run_error() + + usage_events = [e for e in stream.get("CUSTOM") if getattr(e, "name", None) == "usage"] + assert len(usage_events) == 1 + assert usage_events[0].value["input_token_count"] == 100 + assert usage_events[0].value["total_token_count"] == 150 + + +# ────────────────────────────────────────────────────────────────────── +# 14. Approval flow (Content-based request_info) +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_approval_flow_round_trip() -> None: + """function_approval_request via request_info, then resume with approval response.""" + + class ApprovalExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="approval_exec") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + function_call = Content.from_function_call( + call_id="refund-call", + name="submit_refund", + arguments={"order_id": "12345", "amount": "$89.99"}, + ) + approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call) + await ctx.request_info(approval_request, Content, request_id="approval-1") + + @response_handler + async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: + status = "approved" if bool(response.approved) else "rejected" + await ctx.yield_output(f"Refund {status}.") + + workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1: request approval + stream1 = await _run(wrapper, _payload(thread_id="thread-approval", run_id="run-1")) + stream1.assert_bookends() + stream1.assert_no_run_error() + + finished1 = stream1.last("RUN_FINISHED") + interrupt1 = finished1.model_dump().get("interrupt") + assert interrupt1, "Expected approval interrupt" + interrupt_value = interrupt1[0]["value"] + + # Turn 2: approve + stream2 = await _run( + wrapper, + { + "thread_id": "thread-approval", + "run_id": "run-2", + "messages": [], + "resume": { + "interrupts": [ + { + "id": "approval-1", + "value": { + "type": "function_approval_response", + "approved": True, + "id": interrupt_value.get("id", "approval-1"), + "function_call": interrupt_value.get("function_call"), + }, + } + ] + }, + }, + ) + stream2.assert_has_run_lifecycle() + stream2.assert_no_run_error() + stream2.assert_text_messages_balanced() + + deltas = [e.delta for e in stream2.get("TEXT_MESSAGE_CONTENT")] + assert any("approved" in d for d in deltas) + + # No more interrupt + finished2 = stream2.last("RUN_FINISHED") + assert not finished2.model_dump().get("interrupt") + + +# ────────────────────────────────────────────────────────────────────── +# 15. Message list request/response coercion +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_message_list_resume() -> None: + """Resume with list[Message] payload coerces correctly into workflow response.""" + + class MessageRequestExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="msg_request") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info({"prompt": "Need follow-up"}, list[Message], request_id="handoff") + + @response_handler + async def handle_input(self, original: dict, response: list[Message], ctx: WorkflowContext) -> None: + user_text = response[0].text if response else "" + await ctx.yield_output(f"Got: {user_text}") + + workflow = WorkflowBuilder(start_executor=MessageRequestExecutor()).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1 + await _run(wrapper, _payload(thread_id="thread-msg", run_id="run-1")) + + # Turn 2: resume with message list + stream2 = await _run( + wrapper, + { + "thread_id": "thread-msg", + "run_id": "run-2", + "messages": [], + "resume": { + "interrupts": [ + { + "id": "handoff", + "value": [ + {"role": "user", "contents": [{"type": "text", "text": "Ship a replacement"}]}, + ], + } + ] + }, + }, + ) + stream2.assert_has_run_lifecycle() + stream2.assert_no_run_error() + stream2.assert_text_messages_balanced() + + deltas = [e.delta for e in stream2.get("TEXT_MESSAGE_CONTENT")] + assert any("replacement" in d for d in deltas) + + +# ────────────────────────────────────────────────────────────────────── +# 16. Plain text follow-up does NOT infer interrupt response +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_plain_text_does_not_resume_pending_dict_request() -> None: + """Plain text user follow-up should NOT be coerced into a dict response.""" + + @executor(id="requester") + async def requester(message: Any, ctx: WorkflowContext) -> None: + await ctx.request_info( + {"message": "Choose a flight", "options": [{"airline": "KLM"}], "agent": "flights"}, + dict, + request_id="flights-choice", + ) + + workflow = WorkflowBuilder(start_executor=requester).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1 + await _run(wrapper, _payload(thread_id="thread-nocoerce", run_id="run-1")) + + # Turn 2: plain text follow-up with request_info tool call in history + stream2 = await _run( + wrapper, + { + "thread_id": "thread-nocoerce", + "run_id": "run-2", + "messages": [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "flights-choice", + "type": "function", + "function": {"name": "request_info", "arguments": "{}"}, + } + ], + }, + {"role": "user", "content": "I prefer KLM please"}, + ], + }, + ) + stream2.assert_bookends() + stream2.assert_no_run_error() + + # Should still have the interrupt (text was not accepted as dict response) + finished = stream2.last("RUN_FINISHED") + interrupts = finished.model_dump().get("interrupt") + assert isinstance(interrupts, list) + assert interrupts[0]["id"] == "flights-choice" + + +# ────────────────────────────────────────────────────────────────────── +# 17. Workflow factory (thread-scoped workflows) +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_factory_thread_scoping() -> None: + """workflow_factory creates separate workflow instances per thread_id.""" + + def make_workflow(thread_id: str): + @executor(id="echo") + async def echo(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output(f"Thread: {thread_id}") + + return WorkflowBuilder(start_executor=echo).build() + + wrapper = AgentFrameworkWorkflow(workflow_factory=make_workflow) + + stream_a = await _run(wrapper, _payload(thread_id="thread-a", run_id="run-a")) + stream_b = await _run(wrapper, _payload(thread_id="thread-b", run_id="run-b")) + + stream_a.assert_bookends() + stream_b.assert_bookends() + + deltas_a = [e.delta for e in stream_a.get("TEXT_MESSAGE_CONTENT")] + deltas_b = [e.delta for e in stream_b.get("TEXT_MESSAGE_CONTENT")] + assert any("thread-a" in d for d in deltas_a) + assert any("thread-b" in d for d in deltas_b) + + +# ────────────────────────────────────────────────────────────────────── +# 18. Multiple request_info calls in sequence +# ────────────────────────────────────────────────────────────────────── + + +async def test_workflow_sequential_request_info_interrupts() -> None: + """Two chained executors each requesting info: first triggers interrupt, resume, then second triggers interrupt. + + This mirrors the subgraphs_agent pattern where separate executors handle sequential interactions. + """ + + class NameRequester(Executor): + def __init__(self) -> None: + super().__init__(id="name_requester") + + @handler + async def start(self, message: Any, ctx: WorkflowContext[str]) -> None: + await ctx.request_info("What's your name?", str, request_id="name-req") + + @response_handler + async def handle_name(self, original: str, response: str, ctx: WorkflowContext[str]) -> None: + await ctx.send_message(response) + + class DestRequester(Executor): + def __init__(self) -> None: + super().__init__(id="dest_requester") + + @handler + async def start(self, message: str, ctx: WorkflowContext[str]) -> None: + self._name = message + await ctx.request_info("Where to?", str, request_id="dest-req") + + @response_handler + async def handle_dest(self, original: str, response: str, ctx: WorkflowContext[str]) -> None: + await ctx.yield_output(f"Booking for {self._name} to {response}") + + name_requester = NameRequester() + dest_requester = DestRequester() + workflow = WorkflowBuilder(start_executor=name_requester).add_chain([name_requester, dest_requester]).build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + + # Turn 1 + stream1 = await _run(wrapper, _payload(thread_id="thread-seq", run_id="run-1")) + stream1.assert_bookends() + stream1.assert_tool_calls_balanced() + interrupt1 = stream1.last("RUN_FINISHED").model_dump().get("interrupt") + assert interrupt1[0]["id"] == "name-req" + + # Turn 2: answer name → triggers second executor's request_info + stream2 = await _run( + wrapper, + { + "thread_id": "thread-seq", + "run_id": "run-2", + "messages": [], + "resume": {"interrupts": [{"id": "name-req", "value": "Alice"}]}, + }, + ) + stream2.assert_has_run_lifecycle() + stream2.assert_tool_calls_balanced() + interrupt2 = stream2.last("RUN_FINISHED").model_dump().get("interrupt") + assert interrupt2[0]["id"] == "dest-req" + + # Turn 3: answer destination → completion + stream3 = await _run( + wrapper, + { + "thread_id": "thread-seq", + "run_id": "run-3", + "messages": [], + "resume": {"interrupts": [{"id": "dest-req", "value": "Paris"}]}, + }, + ) + stream3.assert_has_run_lifecycle() + stream3.assert_no_run_error() + stream3.assert_text_messages_balanced() + + deltas = [e.delta for e in stream3.get("TEXT_MESSAGE_CONTENT")] + assert any("Alice" in d and "Paris" in d for d in deltas) + assert not stream3.last("RUN_FINISHED").model_dump().get("interrupt") diff --git a/python/packages/ag-ui/tests/ag_ui/sse_helpers.py b/python/packages/ag-ui/tests/ag_ui/sse_helpers.py new file mode 100644 index 0000000000..8a71dd9afb --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/sse_helpers.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""SSE parsing helpers for AG-UI HTTP round-trip tests.""" + +from __future__ import annotations + +import json +from typing import Any + +from event_stream import EventStream + + +def parse_sse_response(response_content: bytes) -> list[dict[str, Any]]: + """Parse raw SSE bytes from TestClient into a list of event dicts. + + Each SSE event is a ``data: {...}`` line followed by a blank line. + """ + text = response_content.decode("utf-8") + events: list[dict[str, Any]] = [] + decode_errors: list[str] = [] + for line in text.splitlines(): + if line.startswith("data: "): + payload = line[6:] + try: + events.append(json.loads(payload)) + except json.JSONDecodeError as exc: + decode_errors.append(f"payload={payload!r}, error={exc}") + continue + if decode_errors: + joined = "; ".join(decode_errors) + raise AssertionError(f"Failed to decode one or more SSE data lines: {joined}") + return events + + +def parse_sse_to_event_stream(response_content: bytes) -> EventStream: + """Parse SSE bytes and wrap in EventStream for structured assertions. + + Returns an EventStream over lightweight SimpleNamespace objects that + mirror AG-UI event attributes (type, message_id, tool_call_id, etc.) + so that EventStream assertion methods work. + """ + from types import SimpleNamespace + + raw_events = parse_sse_response(response_content) + events: list[Any] = [] + for raw in raw_events: + # Normalize camelCase keys to snake_case attributes that EventStream expects + ns = SimpleNamespace() + ns.type = raw.get("type", "") + ns.raw = raw + # Map common camelCase fields + for camel, snake in _FIELD_MAP.items(): + if camel in raw: + setattr(ns, snake, raw[camel]) + # Also keep camelCase as attributes for direct access + for key, value in raw.items(): + if not hasattr(ns, key): + setattr(ns, key, value) + events.append(ns) + return EventStream(events) + + +_FIELD_MAP: dict[str, str] = { + "messageId": "message_id", + "runId": "run_id", + "threadId": "thread_id", + "toolCallId": "tool_call_id", + "toolCallName": "tool_call_name", + "toolName": "tool_call_name", + "parentMessageId": "parent_message_id", + "stepName": "step_name", +} diff --git a/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py b/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py index b6d2152d2a..df6359b8ba 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/ag_ui/test_ag_ui_client.py @@ -21,7 +21,7 @@ from agent_framework_ag_ui._http_service import AGUIHttpService -class TestableAGUIChatClient(AGUIChatClient): +class StubAGUIChatClient(AGUIChatClient): """Testable wrapper exposing protected helpers.""" @property @@ -53,19 +53,19 @@ class TestAGUIChatClient: async def test_client_initialization(self) -> None: """Test client initialization.""" - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") assert client.http_service is not None assert client.http_service.endpoint.startswith("http://localhost:8888") async def test_client_context_manager(self) -> None: """Test client as async context manager.""" - async with TestableAGUIChatClient(endpoint="http://localhost:8888/") as client: + async with StubAGUIChatClient(endpoint="http://localhost:8888/") as client: assert client is not None async def test_extract_state_from_messages_no_state(self) -> None: """Test state extraction when no state is present.""" - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") messages = [ Message(role="user", text="Hello"), Message(role="assistant", text="Hi there"), @@ -80,7 +80,7 @@ async def test_extract_state_from_messages_with_state(self) -> None: """Test state extraction from last message.""" import base64 - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") state_data = {"key": "value", "count": 42} state_json = json.dumps(state_data) @@ -104,7 +104,7 @@ async def test_extract_state_invalid_json(self) -> None: """Test state extraction with invalid JSON.""" import base64 - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") invalid_json = "not valid json" state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8") @@ -123,7 +123,7 @@ async def test_extract_state_invalid_json(self) -> None: async def test_convert_messages_to_agui_format(self) -> None: """Test message conversion to AG-UI format.""" - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") messages = [ Message(role="user", text="What is the weather?"), Message(role="assistant", text="Let me check.", message_id="msg_123"), @@ -140,7 +140,7 @@ async def test_convert_messages_to_agui_format(self) -> None: async def test_get_thread_id_from_metadata(self) -> None: """Test thread ID extraction from metadata.""" - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"}) thread_id = client.get_thread_id(chat_options) @@ -149,7 +149,7 @@ async def test_get_thread_id_from_metadata(self) -> None: async def test_get_thread_id_generation(self) -> None: """Test automatic thread ID generation.""" - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") chat_options = ChatOptions() thread_id = client.get_thread_id(chat_options) @@ -170,7 +170,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str for event in mock_events: yield event - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", text="Test message")] @@ -203,7 +203,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str for event in mock_events: yield event - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", text="Test message")] @@ -246,7 +246,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str for event in mock_events: yield event - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", text="Test with tools")] @@ -270,7 +270,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str for event in mock_events: yield event - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", text="Test server tool execution")] @@ -312,7 +312,7 @@ async def fake_auto_invoke(*args: object, **kwargs: Any) -> None: monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke) - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", text="Test server tool execution")] @@ -348,7 +348,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str for event in mock_events: yield event - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) chat_options = ChatOptions() @@ -357,6 +357,81 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str assert response is not None + async def test_extract_state_from_empty_messages(self) -> None: + """Empty messages list returns empty list and None state.""" + client = StubAGUIChatClient(endpoint="http://localhost:8888/") + result_messages, state = client.extract_state_from_messages([]) + assert result_messages == [] + assert state is None + + async def test_register_server_tool_non_dict_config(self) -> None: + """Non-dict function_invocation_configuration is a no-op.""" + client = StubAGUIChatClient( + endpoint="http://localhost:8888/", + function_invocation_configuration=None, # type: ignore[arg-type] + ) + # Should not raise + client._register_server_tool_placeholder("some_tool") + + async def test_non_streaming_response(self, monkeypatch: MonkeyPatch) -> None: + """Non-streaming path collects updates into ChatResponse.""" + mock_events = [ + {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"}, + {"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"}, + {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, + ] + + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: + for event in mock_events: + yield event + + client = StubAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) + + messages = [Message(role="user", text="Test")] + response = await client.inner_get_response(messages=messages, options={}, stream=False) + + assert response is not None + assert len(response.messages) > 0 + + async def test_client_tool_sets_additional_properties(self, monkeypatch: MonkeyPatch) -> None: + """Client tool content gets agui_thread_id additional property.""" + + @tool + def my_tool(param: str) -> str: + """My tool.""" + return "result" + + mock_events = [ + {"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"}, + {"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "my_tool"}, + {"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"param": "test"}'}, + {"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"}, + ] + + async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str, Any], None]: + for event in mock_events: + yield event + + client = StubAGUIChatClient(endpoint="http://localhost:8888/") + monkeypatch.setattr(client.http_service, "post_run", mock_post_run) + + messages = [Message(role="user", text="Test")] + updates: list[ChatResponseUpdate] = [] + async for update in client._inner_get_response(messages=messages, stream=True, options={"tools": [my_tool]}): + updates.append(update) + + # Find the function_call content - it should have agui_thread_id + found = False + for update in updates: + for content in update.contents: + if content.type == "function_call" and content.name == "my_tool": + assert content.additional_properties is not None + assert "agui_thread_id" in content.additional_properties + found = True + break + assert found, "Expected to find function_call content for my_tool" + async def test_interrupt_options_transmission(self, monkeypatch: MonkeyPatch) -> None: """Interrupt option fields are forwarded to the HTTP service.""" available_interrupts = [{"id": "req_1", "type": "request_info"}] @@ -373,7 +448,7 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str for event in mock_events: yield event - client = TestableAGUIChatClient(endpoint="http://localhost:8888/") + client = StubAGUIChatClient(endpoint="http://localhost:8888/") monkeypatch.setattr(client.http_service, "post_run", mock_post_run) messages = [Message(role="user", text="continue")] diff --git a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py index 6b65a6ab51..51ab468b84 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py +++ b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py @@ -550,3 +550,56 @@ async def test_endpoint_without_dependencies_is_accessible(build_chat_client): assert response.status_code == 200 assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + +async def test_endpoint_invalid_agent_type_raises_typeerror(): + """Passing an invalid agent type raises TypeError.""" + app = FastAPI() + + with pytest.raises(TypeError, match="must be SupportsAgentRun"): + add_agent_framework_fastapi_endpoint(app, agent="not_an_agent") # type: ignore[arg-type] + + +async def test_endpoint_encoding_failure_emits_run_error(): + """Event encoding failure emits RUN_ERROR event in the SSE stream.""" + from unittest.mock import patch + + class SimpleWorkflow(AgentFrameworkWorkflow): + async def run(self, input_data: dict[str, Any]): + del input_data + yield RunStartedEvent(run_id="run-1", thread_id="thread-1") + + app = FastAPI() + add_agent_framework_fastapi_endpoint(app, SimpleWorkflow(), path="/encode-fail") + client = TestClient(app) + + with patch("ag_ui.encoder.EventEncoder.encode") as mock_encode: + # First call fails (the RUN_STARTED event), second call succeeds (the error event) + mock_encode.side_effect = [ValueError("encode boom"), 'data: {"type":"RUN_ERROR"}\n\n'] + response = client.post("/encode-fail", json={"messages": [{"role": "user", "content": "go"}]}) + + assert response.status_code == 200 + content = response.content.decode("utf-8") + assert "RUN_ERROR" in content + + +async def test_endpoint_double_encoding_failure_terminates(): + """When both event and error encoding fail, stream terminates gracefully.""" + from unittest.mock import patch + + class SimpleWorkflow(AgentFrameworkWorkflow): + async def run(self, input_data: dict[str, Any]): + del input_data + yield RunStartedEvent(run_id="run-1", thread_id="thread-1") + + app = FastAPI() + add_agent_framework_fastapi_endpoint(app, SimpleWorkflow(), path="/double-fail") + client = TestClient(app) + + with patch("ag_ui.encoder.EventEncoder.encode") as mock_encode: + # Both calls fail - event encode and error event encode + mock_encode.side_effect = ValueError("always fails") + response = client.post("/double-fail", json={"messages": [{"role": "user", "content": "go"}]}) + + # Should still get 200 (SSE stream), just with no events + assert response.status_code == 200 diff --git a/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py b/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py new file mode 100644 index 0000000000..7e4712535c --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/test_http_round_trip.py @@ -0,0 +1,215 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""HTTP round-trip tests: POST → SSE bytes → parse → validate event sequence. + +These tests exercise the full HTTP pipeline using FastAPI TestClient, +parsing the raw SSE byte stream and validating through EventStream assertions. +""" + +from __future__ import annotations + +from typing import Any + +from agent_framework import AgentResponseUpdate, Content, WorkflowBuilder, WorkflowContext, executor +from conftest import StubAgent +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sse_helpers import parse_sse_response, parse_sse_to_event_stream +from typing_extensions import Never + +from agent_framework_ag_ui import AgentFrameworkAgent, AgentFrameworkWorkflow, add_agent_framework_fastapi_endpoint + + +def _build_app_with_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> FastAPI: + stub = StubAgent(updates=updates) + agent = AgentFrameworkAgent(agent=stub, **kwargs) + app = FastAPI() + add_agent_framework_fastapi_endpoint(app, agent) + return app + + +def _build_app_with_workflow(workflow_builder: WorkflowBuilder) -> FastAPI: + workflow = workflow_builder.build() + wrapper = AgentFrameworkWorkflow(workflow=workflow) + app = FastAPI() + add_agent_framework_fastapi_endpoint(app, wrapper) + return app + + +USER_PAYLOAD: dict[str, Any] = { + "messages": [{"role": "user", "content": "Hello"}], + "threadId": "thread-http", + "runId": "run-http", +} + + +# ── Agentic chat SSE round-trip ── + + +def test_agentic_chat_sse_round_trip() -> None: + """Full HTTP round-trip: POST → SSE bytes → parse → validate event sequence.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate(contents=[Content.from_text(text="Hi there!")], role="assistant"), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + assert response.status_code == 200 + assert "text/event-stream" in response.headers["content-type"] + + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + stream.assert_text_messages_balanced() + stream.assert_no_run_error() + stream.assert_ordered_types( + [ + "RUN_STARTED", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "MESSAGES_SNAPSHOT", + "RUN_FINISHED", + ] + ) + + +# ── Tool call SSE round-trip ── + + +def test_tool_call_sse_round_trip() -> None: + """Tool call events survive SSE encoding/parsing round-trip.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments='{"city": "SF"}')], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's warm!")], + role="assistant", + ), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + stream.assert_tool_calls_balanced() + stream.assert_text_messages_balanced() + + # Verify tool call details survive SSE encoding + start = stream.first("TOOL_CALL_START") + assert start.tool_call_name == "get_weather" + assert start.tool_call_id == "call-1" + + +# ── SSE encoding fidelity ── + + +def test_sse_event_encoding_fidelity() -> None: + """Every event from agent.run() produces a valid SSE data: line that round-trips.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate(contents=[Content.from_text(text="Hello world")], role="assistant"), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + raw_events = parse_sse_response(response.content) + assert len(raw_events) > 0, "No SSE events parsed" + + # Every event should have a 'type' field + for event in raw_events: + assert "type" in event, f"Event missing 'type': {event}" + + # Event types should include the expected ones + event_types = [e["type"] for e in raw_events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + + +# ── camelCase request field acceptance ── + + +def test_camel_case_request_fields_accepted() -> None: + """Request with camelCase fields (runId, threadId) is correctly parsed.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate(contents=[Content.from_text(text="ok")], role="assistant"), + ] + ) + client = TestClient(app) + response = client.post( + "/", + json={ + "messages": [{"role": "user", "content": "hi"}], + "runId": "camel-run", + "threadId": "camel-thread", + }, + ) + assert response.status_code == 200 + + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + + +# ── Workflow SSE round-trip ── + + +def test_workflow_sse_round_trip() -> None: + """Workflow events survive SSE encoding/parsing.""" + + @executor(id="greeter") + async def greeter(message: Any, ctx: WorkflowContext[Never, str]) -> None: + await ctx.yield_output("Hello from workflow!") + + app = _build_app_with_workflow(WorkflowBuilder(start_executor=greeter)) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + assert response.status_code == 200 + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + stream.assert_no_run_error() + stream.assert_text_messages_balanced() + stream.assert_has_type("STEP_STARTED") + + +# ── Error handling ── + + +def test_empty_messages_returns_valid_sse() -> None: + """Empty messages list still returns a valid SSE stream with bookends.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate(contents=[Content.from_text(text="ok")], role="assistant"), + ] + ) + client = TestClient(app) + response = client.post("/", json={"messages": []}) + + assert response.status_code == 200 + stream = parse_sse_to_event_stream(response.content) + stream.assert_bookends() + + +def test_sse_response_headers() -> None: + """SSE response has correct headers for event streaming.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate(contents=[Content.from_text(text="ok")], role="assistant"), + ] + ) + client = TestClient(app) + response = client.post("/", json=USER_PAYLOAD) + + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert response.headers.get("cache-control") == "no-cache" diff --git a/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py b/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py index bc1b95ad7d..5227d376bb 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py +++ b/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py @@ -868,6 +868,648 @@ def test_agui_messages_to_snapshot_format_basic(): assert result[1]["content"] == "Hi there" +# ── Tool history sanitization edge cases ── + + +def test_sanitize_multiple_approvals_and_logic(): + """Two function_approval_response contents: True + False → False overall.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + assistant_msg = Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + Content.from_function_call(call_id="c2", name="confirm_changes", arguments='{"function_call_id":"c1"}'), + ], + ) + user_msg = Message( + role="user", + contents=[ + Content.from_function_approval_response( + approved=True, + id="a1", + function_call=Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + ), + Content.from_function_approval_response( + approved=False, + id="a2", + function_call=Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + ), + ], + ) + + result = _sanitize_tool_history([assistant_msg, user_msg]) + # Both approvals should be preserved in user message + assert any(msg.role == "user" for msg in result) + + +def test_sanitize_pending_tool_skip_on_user_followup(): + """User text message after assistant tool call injects synthetic skipped results.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + assistant_msg = Message( + role="assistant", + contents=[Content.from_function_call(call_id="c1", name="get_weather", arguments="{}")], + ) + user_msg = Message( + role="user", + contents=[Content.from_text(text="Actually, never mind")], + ) + + result = _sanitize_tool_history([assistant_msg, user_msg]) + # Should have: assistant, synthetic tool result, user + tool_results = [m for m in result if m.role == "tool"] + assert len(tool_results) == 1 + assert "skipped" in str(tool_results[0].contents[0].result).lower() + + +def test_sanitize_tool_result_clears_pending_confirm(): + """Tool result for pending confirm_changes call_id clears pending state.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + assistant_msg = Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + ], + ) + tool_msg = Message( + role="tool", + contents=[Content.from_function_result(call_id="c1", result="done")], + ) + + result = _sanitize_tool_history([assistant_msg, tool_msg]) + assert len(result) == 2 + assert result[1].role == "tool" + + +def test_sanitize_non_standard_role_resets_state(): + """System message between assistant+user resets pending tool state.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + assistant_msg = Message( + role="assistant", + contents=[Content.from_function_call(call_id="c1", name="get_weather", arguments="{}")], + ) + system_msg = Message(role="system", contents=[Content.from_text(text="System update")]) + user_msg = Message(role="user", contents=[Content.from_text(text="Continue")]) + + result = _sanitize_tool_history([assistant_msg, system_msg, user_msg]) + # System message should reset pending state, so no synthetic tool results + tool_results = [m for m in result if m.role == "tool"] + assert len(tool_results) == 0 + + +def test_sanitize_json_confirm_changes_response(): + """User sends JSON text with 'accepted' after confirm_changes.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + assistant_msg = Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + Content.from_function_call(call_id="c2", name="confirm_changes", arguments='{"function_call_id":"c1"}'), + ], + ) + # Note: confirm_changes is filtered, so c2 won't be in pending_tool_call_ids + # But c1 will remain pending. User message with JSON accepted text doesn't match + # confirm_changes path since pending_confirm_changes_id was reset. + user_msg = Message( + role="user", + contents=[Content.from_text(text=json.dumps({"accepted": True}))], + ) + + result = _sanitize_tool_history([assistant_msg, user_msg]) + # Should still process without errors + assert len(result) >= 1 + + +# ── Deduplication edge cases ── + + +def test_deduplicate_tool_results(): + """Duplicate tool results for same call_id are deduplicated.""" + from agent_framework_ag_ui._message_adapters import _deduplicate_messages + + msg1 = Message(role="tool", contents=[Content.from_function_result(call_id="c1", result="first")]) + msg2 = Message(role="tool", contents=[Content.from_function_result(call_id="c1", result="second")]) + + result = _deduplicate_messages([msg1, msg2]) + assert len(result) == 1 + + +def test_deduplicate_assistant_tool_calls(): + """Duplicate assistant messages with same tool_calls are deduplicated.""" + from agent_framework_ag_ui._message_adapters import _deduplicate_messages + + msg1 = Message( + role="assistant", + contents=[Content.from_function_call(call_id="c1", name="fn", arguments="{}")], + ) + msg2 = Message( + role="assistant", + contents=[Content.from_function_call(call_id="c1", name="fn", arguments="{}")], + ) + + result = _deduplicate_messages([msg1, msg2]) + assert len(result) == 1 + + +def test_deduplicate_general_messages(): + """Duplicate general user messages are deduplicated.""" + from agent_framework_ag_ui._message_adapters import _deduplicate_messages + + msg1 = Message(role="user", contents=[Content.from_text(text="Hello")]) + msg2 = Message(role="user", contents=[Content.from_text(text="Hello")]) + + result = _deduplicate_messages([msg1, msg2]) + assert len(result) == 1 + + +def test_deduplicate_replaces_empty_tool_result(): + """Empty tool result is replaced by later non-empty result.""" + from agent_framework_ag_ui._message_adapters import _deduplicate_messages + + msg1 = Message(role="tool", contents=[Content.from_function_result(call_id="c1", result="")]) + msg2 = Message(role="tool", contents=[Content.from_function_result(call_id="c1", result="actual result")]) + + result = _deduplicate_messages([msg1, msg2]) + assert len(result) == 1 + assert result[0].contents[0].result == "actual result" + + +# ── Multimodal & content conversion edge cases ── + + +def test_convert_agui_content_unknown_source_type_fallback(): + """Unknown source type falls back to url/data/id fields.""" + from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part + + part = { + "type": "image", + "source": {"type": "custom", "url": "https://example.com/img.png"}, + } + result = _parse_multimodal_media_part(part) + assert result is not None + assert result.uri == "https://example.com/img.png" + + +def test_convert_agui_content_data_uri_prefix(): + """base64 data starting with 'data:' is treated as data URI.""" + from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part + + part = { + "type": "image", + "source": {"type": "base64", "data": "data:image/png;base64,abc", "mimeType": "image/png"}, + } + result = _parse_multimodal_media_part(part) + assert result is not None + assert result.uri == "data:image/png;base64,abc" + + +def test_convert_agui_content_binary_id(): + """Source with 'id' field creates ag-ui:// URI.""" + from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part + + part = { + "type": "image", + "source": {"type": "id", "id": "file123"}, + } + result = _parse_multimodal_media_part(part) + assert result is not None + assert result.uri == "ag-ui://binary/file123" + + +def test_convert_agui_content_string_items_in_list(): + """String items in content list create text Content.""" + from agent_framework_ag_ui._message_adapters import _convert_agui_content_to_framework + + result = _convert_agui_content_to_framework(["hello", "world"]) + assert len(result) == 2 + assert result[0].text == "hello" + assert result[1].text == "world" + + +def test_convert_agui_content_non_dict_non_str_items(): + """Non-dict/non-str items in list are stringified.""" + from agent_framework_ag_ui._message_adapters import _convert_agui_content_to_framework + + result = _convert_agui_content_to_framework([123, None]) + assert len(result) == 2 + assert result[0].text == "123" + assert result[1].text == "None" + + +def test_convert_agui_content_unknown_part_type_with_text(): + """Unknown part type with 'text' key extracts the text.""" + from agent_framework_ag_ui._message_adapters import _convert_agui_content_to_framework + + result = _convert_agui_content_to_framework([{"type": "widget", "text": "hi"}]) + assert len(result) == 1 + assert result[0].text == "hi" + + +def test_convert_agui_content_unknown_part_type_without_text(): + """Unknown part type without 'text' key stringifies the dict.""" + from agent_framework_ag_ui._message_adapters import _convert_agui_content_to_framework + + result = _convert_agui_content_to_framework([{"type": "widget", "data": 42}]) + assert len(result) == 1 + assert "widget" in result[0].text + + +def test_convert_agui_content_none(): + """None content returns empty list.""" + from agent_framework_ag_ui._message_adapters import _convert_agui_content_to_framework + + result = _convert_agui_content_to_framework(None) + assert result == [] + + +def test_convert_agui_content_non_str_non_list_non_none(): + """Non-string, non-list, non-None content is stringified.""" + from agent_framework_ag_ui._message_adapters import _convert_agui_content_to_framework + + result = _convert_agui_content_to_framework(42) + assert len(result) == 1 + assert result[0].text == "42" + + +# ── Snapshot normalization edge cases ── + + +def test_snapshot_input_image_to_binary(): + """input_image type is normalized to binary in snapshot.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "user", + "content": [ + {"type": "input_image", "source": {"type": "url", "url": "https://example.com/img.png"}}, + ], + } + ] + ) + assert isinstance(result[0]["content"], list) + assert result[0]["content"][0]["type"] == "binary" + + +def test_snapshot_mime_type_snake_case(): + """mime_type (snake_case) is normalized to mimeType.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Caption", "mime_type": "text/plain"}, + { + "type": "image", + "source": {"type": "url", "url": "https://x.com/a.png", "mime_type": "image/png"}, + }, + ], + } + ] + ) + content = result[0]["content"] + assert isinstance(content, list) + # The text part should have mimeType added + text_part = content[0] + assert text_part.get("mimeType") == "text/plain" + + +def test_snapshot_text_only_list_collapsed(): + """List of only text parts is collapsed to string.""" + result = agui_messages_to_snapshot_format( + [{"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": " World"}]}] + ) + assert result[0]["content"] == "Hello World" + + +def test_snapshot_legacy_binary_data_and_id(): + """Legacy binary part with data and id fields.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Caption"}, + {"type": "binary", "data": "base64data", "id": "file1", "mimeType": "image/png"}, + ], + } + ] + ) + content = result[0]["content"] + assert isinstance(content, list) + binary_part = content[1] + assert binary_part["type"] == "binary" + assert binary_part["data"] == "base64data" + assert binary_part["id"] == "file1" + + +# ── Message conversion edge cases ── + + +def test_agui_tool_message_action_execution_id_fallback(): + """Tool message with actionExecutionId but no tool_call_id.""" + messages = agui_messages_to_agent_framework( + [ + { + "role": "tool", + "content": "result data", + "actionExecutionId": "action_1", + } + ] + ) + assert len(messages) == 1 + assert messages[0].contents[0].type == "function_result" + assert messages[0].contents[0].call_id == "action_1" + + +def test_agui_tool_message_result_key_instead_of_content(): + """Tool message with 'result' key instead of 'content'.""" + messages = agui_messages_to_agent_framework( + [ + { + "role": "tool", + "result": "the result", + "toolCallId": "c1", + } + ] + ) + assert len(messages) == 1 + assert messages[0].contents[0].result == "the result" + + +def test_agui_tool_message_dict_content(): + """Tool message with dict content.""" + messages = agui_messages_to_agent_framework( + [ + { + "role": "tool", + "content": {"key": "value"}, + "toolCallId": "c1", + } + ] + ) + assert len(messages) == 1 + # Dict content as approval check: no 'accepted' key, so it's a regular tool result + assert messages[0].contents[0].type == "function_result" + + +def test_agui_tool_message_list_content(): + """Tool message with list content.""" + messages = agui_messages_to_agent_framework( + [ + { + "role": "tool", + "content": ["item1", "item2"], + "toolCallId": "c1", + } + ] + ) + assert len(messages) == 1 + assert messages[0].contents[0].type == "function_result" + + +def test_agui_action_execution_id_without_role(): + """Message with actionExecutionId but no role maps to tool.""" + messages = agui_messages_to_agent_framework( + [ + { + "actionExecutionId": "action_1", + "result": "tool result", + } + ] + ) + assert len(messages) == 1 + assert messages[0].role == "tool" + assert messages[0].contents[0].call_id == "action_1" + + +def test_agui_non_dict_tool_call_skipped(): + """Non-dict tool_call entries in tool_calls array are skipped.""" + messages = agui_messages_to_agent_framework( + [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + "not_a_dict", + { + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + }, + ], + } + ] + ) + assert len(messages) == 1 + func_calls = [c for c in messages[0].contents if c.type == "function_call"] + assert len(func_calls) == 1 + + +def test_agui_empty_content_default(): + """Message with empty/null content gets default empty text.""" + messages = agui_messages_to_agent_framework([{"role": "user"}]) + assert len(messages) == 1 + assert len(messages[0].contents) == 1 + assert messages[0].contents[0].text == "" + + +def test_agui_dict_tool_msg_without_tool_call_id(): + """Dict tool message missing toolCallId gets empty string.""" + result = agui_messages_to_snapshot_format([{"role": "tool", "content": "result"}]) + assert len(result) == 1 + assert result[0].get("toolCallId") == "" + + +def test_snapshot_argument_serialization_none(): + """None arguments in tool_calls are serialized to empty string.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "c1", "type": "function", "function": {"name": "fn", "arguments": None}}, + ], + } + ] + ) + tc = result[0]["tool_calls"][0] + assert tc["function"]["arguments"] == "" + + +def test_snapshot_argument_serialization_object(): + """Object arguments in tool_calls are JSON-serialized.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "c1", "type": "function", "function": {"name": "fn", "arguments": {"key": "val"}}}, + ], + } + ] + ) + tc = result[0]["tool_calls"][0] + assert tc["function"]["arguments"] == '{"key": "val"}' + + +def test_snapshot_tool_call_id_normalization(): + """tool_call_id is normalized to toolCallId in snapshot.""" + result = agui_messages_to_snapshot_format([{"role": "tool", "content": "result", "tool_call_id": "c1"}]) + assert result[0].get("toolCallId") == "c1" + assert "tool_call_id" not in result[0] + + +def test_agui_to_framework_dict_tool_msg_without_tool_call_id(): + """Dict tool message in agent_framework_messages_to_agui without toolCallId.""" + result = agent_framework_messages_to_agui( + [{"role": "tool", "content": "result"}] # type: ignore[list-item] + ) + assert len(result) == 1 + assert result[0].get("toolCallId") == "" + + +def test_snapshot_none_content(): + """None content is normalized to empty string.""" + result = agui_messages_to_snapshot_format([{"role": "user", "content": None}]) + assert result[0]["content"] == "" + + +def test_sanitize_confirm_changes_with_approval_accepted(): + """Approval for pending confirm_changes creates synthetic result.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + # Create assistant with both a real tool and confirm_changes + assistant_msg = Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + Content.from_function_call(call_id="c2", name="confirm_changes", arguments='{"function_call_id":"c1"}'), + ], + ) + # Note: confirm_changes gets filtered out, so pending_confirm_changes_id becomes None. + # The test verifies the filtering path works without error. + user_msg = Message( + role="user", + contents=[ + Content.from_function_approval_response( + approved=True, + id="a1", + function_call=Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + ), + ], + ) + + result = _sanitize_tool_history([assistant_msg, user_msg]) + # Should process without errors; confirm_changes is filtered from assistant msg + assert len(result) >= 1 + + +def test_sanitize_json_accepted_text_for_pending_confirm(): + """JSON text with 'accepted' field for non-filtered confirm_changes path.""" + from agent_framework_ag_ui._message_adapters import _sanitize_tool_history + + # Create an assistant with a tool call that requires a result + assistant_msg = Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="c1", name="tool_a", arguments="{}"), + ], + ) + # A tool result arrives, then a user message + tool_msg = Message( + role="tool", + contents=[Content.from_function_result(call_id="c1", result="done")], + ) + user_msg = Message( + role="user", + contents=[Content.from_text(text="Continue please")], + ) + + result = _sanitize_tool_history([assistant_msg, tool_msg, user_msg]) + # Should have: assistant, tool result, user + assert len(result) == 3 + + +def test_parse_multimodal_media_part_no_data_no_url(): + """Part with no url, data, or id returns None.""" + from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part + + result = _parse_multimodal_media_part({"type": "image"}) + assert result is None + + +def test_parse_multimodal_media_part_binary_source_type(): + """Source with type='binary' extracts data field.""" + from agent_framework_ag_ui._message_adapters import _parse_multimodal_media_part + + result = _parse_multimodal_media_part( + {"type": "image", "source": {"type": "binary", "data": "data:image/png;base64,abc"}} + ) + assert result is not None + assert result.uri == "data:image/png;base64,abc" + + +def test_snapshot_non_dict_item_in_content_list(): + """Non-dict items in content list are stringified.""" + result = agui_messages_to_snapshot_format([{"role": "user", "content": [42, "text"]}]) + # Text-only after stringification means collapsed to string + assert isinstance(result[0]["content"], str) + + +def test_snapshot_non_dict_tool_call_skipped(): + """Non-dict entries in tool_calls are skipped during argument serialization.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + "not_a_dict", + {"id": "c1", "type": "function", "function": {"name": "fn", "arguments": "{}"}}, + ], + } + ] + ) + # Should not error + assert len(result) == 1 + + +def test_snapshot_tool_call_without_function_payload(): + """tool_call dict without function payload is skipped.""" + result = agui_messages_to_snapshot_format( + [ + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "c1", "type": "function"}], + } + ] + ) + assert len(result) == 1 + + +def test_agui_to_framework_action_name_without_role(): + """Message with actionName but no explicit role maps to tool.""" + messages = agui_messages_to_agent_framework([{"actionName": "get_weather", "result": "Sunny", "toolCallId": "c1"}]) + assert len(messages) == 1 + assert messages[0].role == "tool" + + +def test_agui_to_framework_tool_message_content_none(): + """Tool message with content=None uses result field fallback.""" + messages = agui_messages_to_agent_framework( + [{"role": "tool", "content": None, "result": "fallback_result", "toolCallId": "c1"}] + ) + assert len(messages) == 1 + assert messages[0].contents[0].result == "fallback_result" + + def test_agui_fresh_approval_is_still_processed(): """A fresh approval (no assistant response after it) must still produce function_approval_response. diff --git a/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py b/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py new file mode 100644 index 0000000000..714ce2ce50 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/test_multi_turn.py @@ -0,0 +1,332 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Multi-turn conversation tests: POST → collect events → extract snapshot → POST again. + +These tests catch round-trip fidelity bugs: if MessagesSnapshotEvent produces a +malformed message list, the second turn will fail during normalize_agui_input_messages() +or produce incorrect behavior. +""" + +from __future__ import annotations + +import json +from typing import Any + +from agent_framework import AgentResponseUpdate, Content +from conftest import StubAgent +from fastapi import FastAPI +from fastapi.testclient import TestClient +from sse_helpers import parse_sse_response, parse_sse_to_event_stream + +from agent_framework_ag_ui import AgentFrameworkAgent, add_agent_framework_fastapi_endpoint + + +def _build_app_with_agent(updates: list[AgentResponseUpdate], **kwargs: Any) -> FastAPI: + stub = StubAgent(updates=updates) + agent = AgentFrameworkAgent(agent=stub, **kwargs) + app = FastAPI() + add_agent_framework_fastapi_endpoint(app, agent) + return app + + +def _extract_snapshot_messages(response_content: bytes) -> list[dict[str, Any]]: + """Extract the latest MessagesSnapshotEvent.messages from SSE response bytes.""" + raw_events = parse_sse_response(response_content) + snapshot_msgs: list[dict[str, Any]] | None = None + for event in raw_events: + if event.get("type") == "MESSAGES_SNAPSHOT": + snapshot_msgs = event.get("messages", []) + assert snapshot_msgs is not None, "No MESSAGES_SNAPSHOT event found" + return snapshot_msgs + + +# ── Basic multi-turn chat ── + + +def test_basic_multi_turn_chat() -> None: + """Turn 1: user→assistant. Turn 2: user→assistant with prior history from snapshot.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate(contents=[Content.from_text(text="Hello! How can I help?")], role="assistant"), + ] + ) + client = TestClient(app) + + # Turn 1 + resp1 = client.post( + "/", + json={ + "messages": [{"role": "user", "content": "Hi there"}], + "threadId": "thread-multi", + "runId": "run-1", + }, + ) + assert resp1.status_code == 200 + stream1 = parse_sse_to_event_stream(resp1.content) + stream1.assert_bookends() + stream1.assert_text_messages_balanced() + + # Extract snapshot messages from turn 1 + snapshot_messages = _extract_snapshot_messages(resp1.content) + + # Turn 2: send snapshot messages + new user message + turn2_messages = list(snapshot_messages) + [{"role": "user", "content": "Tell me more"}] + resp2 = client.post( + "/", + json={ + "messages": turn2_messages, + "threadId": "thread-multi", + "runId": "run-2", + }, + ) + assert resp2.status_code == 200 + stream2 = parse_sse_to_event_stream(resp2.content) + stream2.assert_bookends() + stream2.assert_text_messages_balanced() + stream2.assert_no_run_error() + + +# ── Tool call history round-trip ── + + +def test_tool_call_history_round_trips() -> None: + """Turn 1: tool call + result. Turn 2: snapshot messages correctly reconstruct tool history.""" + app = _build_app_with_agent( + [ + AgentResponseUpdate( + contents=[Content.from_function_call(name="get_weather", call_id="call-1", arguments='{"city": "SF"}')], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_function_result(call_id="call-1", result="72°F")], + role="assistant", + ), + AgentResponseUpdate( + contents=[Content.from_text(text="It's warm!")], + role="assistant", + ), + ] + ) + client = TestClient(app) + + # Turn 1 + resp1 = client.post( + "/", + json={ + "messages": [{"role": "user", "content": "What's the weather?"}], + "threadId": "thread-tool-multi", + "runId": "run-1", + }, + ) + assert resp1.status_code == 200 + stream1 = parse_sse_to_event_stream(resp1.content) + stream1.assert_tool_calls_balanced() + + # Extract snapshot and verify it has tool history + snapshot_messages = _extract_snapshot_messages(resp1.content) + roles = [m.get("role") for m in snapshot_messages] + assert "tool" in roles or "assistant" in roles, f"Expected tool/assistant messages in snapshot, got: {roles}" + + # Turn 2: send snapshot + new question + turn2_messages = list(snapshot_messages) + [{"role": "user", "content": "What about tomorrow?"}] + resp2 = client.post( + "/", + json={ + "messages": turn2_messages, + "threadId": "thread-tool-multi", + "runId": "run-2", + }, + ) + assert resp2.status_code == 200 + stream2 = parse_sse_to_event_stream(resp2.content) + stream2.assert_bookends() + stream2.assert_no_run_error() + + +# ── Approval interrupt/resume round-trip ── + + +async def test_approval_interrupt_resume_round_trip() -> None: + """Turn 1: approval request → interrupt with confirm_changes. Turn 2: confirm_changes result → confirmation text. + + The confirm_changes flow uses a specific message format that bypasses the agent + and directly emits a confirmation text message. + """ + from event_stream import EventStream + + steps = [{"description": "Execute task", "status": "enabled"}] + + # Build agent with predictive state and confirmation + stub = StubAgent( + updates=[ + AgentResponseUpdate( + contents=[ + Content.from_function_call( + name="generate_task_steps", + call_id="call-steps", + arguments=json.dumps({"steps": steps}), + ) + ], + role="assistant", + ), + ] + ) + agent = AgentFrameworkAgent( + agent=stub, + state_schema={"tasks": {"type": "array"}}, + predict_state_config={"tasks": {"tool": "generate_task_steps", "tool_argument": "steps"}}, + require_confirmation=True, + ) + + # Turn 1 + events1 = [ + e + async for e in agent.run( + { + "thread_id": "thread-approval-multi", + "run_id": "run-1", + "messages": [{"role": "user", "content": "Plan my tasks"}], + "state": {"tasks": []}, + } + ) + ] + stream1 = EventStream(events1) + stream1.assert_bookends() + stream1.assert_tool_calls_balanced() + + # Should have interrupt with function_approval_request + finished1 = stream1.last("RUN_FINISHED") + interrupt1 = finished1.model_dump().get("interrupt") + assert interrupt1, "Expected interrupt in RUN_FINISHED" + + # Verify confirm_changes tool call was emitted + tool_starts = stream1.get("TOOL_CALL_START") + tool_names = [getattr(s, "tool_call_name", None) for s in tool_starts] + assert "confirm_changes" in tool_names, f"Expected confirm_changes in tool calls, got {tool_names}" + + # Turn 2: Direct confirm_changes response (the way CopilotKit sends it) + # Construct the messages as CopilotKit would - with the confirm_changes tool call + # and a tool result + confirm_tool = [s for s in tool_starts if getattr(s, "tool_call_name", None) == "confirm_changes"][0] + confirm_id = confirm_tool.tool_call_id + confirm_args = None + for e in stream1.get("TOOL_CALL_ARGS"): + if e.tool_call_id == confirm_id: + confirm_args = e.delta + break + + turn2_messages = [ + {"role": "user", "content": "Plan my tasks"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": confirm_id, + "type": "function", + "function": {"name": "confirm_changes", "arguments": confirm_args or "{}"}, + }, + ], + }, + { + "role": "tool", + "toolCallId": confirm_id, + "content": json.dumps({"accepted": True, "steps": steps}), + }, + ] + + events2 = [ + e + async for e in agent.run( + { + "thread_id": "thread-approval-multi", + "run_id": "run-2", + "messages": turn2_messages, + "state": {"tasks": []}, + } + ) + ] + stream2 = EventStream(events2) + stream2.assert_bookends() + stream2.assert_text_messages_balanced() + stream2.assert_no_run_error() + + # Turn 2 should have confirmation text (the approval handler generates it) + text_events = stream2.get("TEXT_MESSAGE_CONTENT") + assert text_events, "Expected confirmation text message in turn 2" + + # Turn 2 should NOT have interrupt (approval completed) + finished2 = stream2.last("RUN_FINISHED") + interrupt2 = finished2.model_dump().get("interrupt") + assert not interrupt2, f"Expected no interrupt after approval, got {interrupt2}" + + +# ── Workflow interrupt/resume round-trip ── +# Note: Workflow tests use async agent.run() directly instead of HTTP TestClient +# because the sync TestClient runs in a different event loop, which conflicts +# with the workflow's asyncio Queue. + + +async def test_workflow_interrupt_resume_round_trip() -> None: + """Turn 1: workflow request_info → interrupt. Turn 2: resume → completion.""" + from event_stream import EventStream + + from agent_framework_ag_ui_examples.agents.subgraphs_agent import subgraphs_agent + + agent = subgraphs_agent() + + # Turn 1: initial request → flight interrupt + events1 = [ + event + async for event in agent.run( + { + "messages": [{"role": "user", "content": "Plan a trip to SF"}], + "thread_id": "thread-wf-multi", + "run_id": "run-1", + } + ) + ] + stream1 = EventStream(events1) + stream1.assert_bookends() + stream1.assert_no_run_error() + + finished1 = stream1.last("RUN_FINISHED") + interrupt1 = finished1.model_dump().get("interrupt") + assert interrupt1, "Expected flight interrupt" + assert interrupt1[0]["value"]["agent"] == "flights" + + # Turn 2: resume with flight selection + events2 = [ + event + async for event in agent.run( + { + "messages": [], + "thread_id": "thread-wf-multi", + "run_id": "run-2", + "resume": { + "interrupts": [ + { + "id": interrupt1[0]["id"], + "value": json.dumps( + { + "airline": "United", + "departure": "Amsterdam (AMS)", + "arrival": "San Francisco (SFO)", + "price": "$720", + "duration": "12h 15m", + } + ), + } + ], + }, + } + ) + ] + stream2 = EventStream(events2) + stream2.assert_bookends() + stream2.assert_no_run_error() + + # Should now have hotel interrupt + finished2 = stream2.last("RUN_FINISHED") + interrupt2 = finished2.model_dump().get("interrupt") + assert interrupt2, "Expected hotel interrupt" + assert interrupt2[0]["value"]["agent"] == "hotels" diff --git a/python/packages/ag-ui/tests/ag_ui/test_run_common.py b/python/packages/ag-ui/tests/ag_ui/test_run_common.py new file mode 100644 index 0000000000..526a3c33c1 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/test_run_common.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for _run_common.py edge cases.""" + +from agent_framework import Content + +from agent_framework_ag_ui._run_common import ( + FlowState, + _emit_tool_result, + _extract_resume_payload, + _normalize_resume_interrupts, +) + + +class TestNormalizeResumeInterrupts: + """Tests for _normalize_resume_interrupts edge cases.""" + + def test_plain_list_of_dicts(self): + """Resume payload as a plain list of interrupt dicts.""" + result = _normalize_resume_interrupts([{"id": "x", "value": "y"}]) + assert result == [{"id": "x", "value": "y"}] + + def test_dict_with_singular_interrupt_key(self): + """Resume dict using 'interrupt' (singular) instead of 'interrupts'.""" + result = _normalize_resume_interrupts({"interrupt": [{"id": "x", "value": "y"}]}) + assert result == [{"id": "x", "value": "y"}] + + def test_dict_without_interrupts_key_wraps_as_candidate(self): + """Resume dict without interrupts/interrupt key wraps the dict itself.""" + result = _normalize_resume_interrupts({"id": "x", "value": "y"}) + assert result == [{"id": "x", "value": "y"}] + + def test_non_dict_items_in_list_are_skipped(self): + """Non-dict items in candidate list are silently skipped.""" + result = _normalize_resume_interrupts([None, "string", {"id": "x", "value": "y"}]) + assert result == [{"id": "x", "value": "y"}] + + def test_items_missing_id_are_skipped(self): + """Dict items without any id field are skipped.""" + result = _normalize_resume_interrupts([{"name": "test"}]) + assert result == [] + + def test_response_key_used_as_value(self): + """'response' key is used as value when 'value' is absent.""" + result = _normalize_resume_interrupts([{"id": "x", "response": "approved"}]) + assert result == [{"id": "x", "value": "approved"}] + + def test_neither_value_nor_response_uses_remaining_fields(self): + """When neither 'value' nor 'response' key exists, remaining fields become value.""" + result = _normalize_resume_interrupts([{"id": "x", "extra": "data", "more": 42}]) + assert result == [{"id": "x", "value": {"extra": "data", "more": 42}}] + + def test_none_payload_returns_empty(self): + """None resume payload returns empty list.""" + assert _normalize_resume_interrupts(None) == [] + + def test_non_dict_non_list_returns_empty(self): + """Non-dict, non-list payload returns empty list.""" + assert _normalize_resume_interrupts(42) == [] + + def test_interrupt_id_key_used_as_id(self): + """interruptId key is accepted as identifier.""" + result = _normalize_resume_interrupts([{"interruptId": "abc", "value": "yes"}]) + assert result == [{"id": "abc", "value": "yes"}] + + def test_tool_call_id_key_used_as_id(self): + """toolCallId key is accepted as identifier.""" + result = _normalize_resume_interrupts([{"toolCallId": "tc1", "value": "done"}]) + assert result == [{"id": "tc1", "value": "done"}] + + +class TestExtractResumePayload: + """Tests for _extract_resume_payload edge cases.""" + + def test_forwarded_props_resume_not_nested_in_command(self): + """forwarded_props.resume (not nested in command) is extracted.""" + result = _extract_resume_payload({"forwarded_props": {"resume": "data"}}) + assert result == "data" + + def test_forwarded_props_not_dict_returns_none(self): + """Non-dict forwarded_props returns None.""" + result = _extract_resume_payload({"forwarded_props": "string"}) + assert result is None + + def test_resume_key_has_priority(self): + """Direct resume key takes priority over forwarded_props.""" + result = _extract_resume_payload({"resume": "direct", "forwarded_props": {"resume": "fp"}}) + assert result == "direct" + + def test_no_resume_at_all(self): + """No resume key anywhere returns None.""" + result = _extract_resume_payload({"messages": []}) + assert result is None + + def test_forwarded_props_camelcase(self): + """camelCase forwardedProps is also supported.""" + result = _extract_resume_payload({"forwardedProps": {"resume": "camel"}}) + assert result == "camel" + + +class TestEmitToolResult: + """Tests for _emit_tool_result edge cases.""" + + def test_tool_result_without_call_id_returns_empty(self): + """Tool result Content without call_id returns empty event list.""" + content = Content.from_function_result(call_id=None, result="some result") + flow = FlowState() + events = _emit_tool_result(content, flow) + assert events == [] + + def test_tool_result_closes_open_text_message(self): + """Tool result closes any open text message (issue #3568 fix).""" + content = Content.from_function_result(call_id="call_1", result="done") + flow = FlowState(message_id="msg_1", accumulated_text="Hello") + events = _emit_tool_result(content, flow) + + event_types = [e.type for e in events] + assert "TOOL_CALL_END" in event_types + assert "TOOL_CALL_RESULT" in event_types + assert "TEXT_MESSAGE_END" in event_types + assert flow.message_id is None + assert flow.accumulated_text == "" diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py index 8497145c56..8ebd8fcaaa 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py @@ -3,12 +3,14 @@ """Tests for native workflow AG-UI runner.""" import json +from enum import Enum from types import SimpleNamespace from typing import Any, cast from ag_ui.core import EventType, StateSnapshotEvent from agent_framework import ( AgentResponse, + AgentResponseUpdate, Content, Executor, Message, @@ -22,8 +24,25 @@ from typing_extensions import Never from agent_framework_ag_ui._workflow_run import ( + _coerce_content, + _coerce_json_value, _coerce_message, + _coerce_message_content, _coerce_response_for_request, + _coerce_responses_for_pending_requests, + _custom_event_value, + _details_code, + _details_message, + _interrupt_entry_for_request_event, + _latest_assistant_contents, + _latest_user_text, + _message_role_value, + _pending_request_events, + _request_payload_from_request_event, + _single_pending_response_from_value, + _text_from_contents, + _workflow_interrupt_event_value, + _workflow_payload_to_contents, run_workflow_stream, ) @@ -677,3 +696,734 @@ async def _stream(): assert "RUN_ERROR" in event_types run_error = next(event for event in events if event.type == "RUN_ERROR") assert "workflow stream exploded" in run_error.message + + +# ── Helper function unit tests ── + + +class TestPendingRequestEvents: + """Tests for _pending_request_events helper.""" + + async def test_no_runner_context(self): + """Workflow without _runner_context returns empty dict.""" + workflow = SimpleNamespace() + result = await _pending_request_events(cast(Any, workflow)) + assert result == {} + + async def test_runner_context_missing_get_pending(self): + """Runner context without get_pending_request_info_events returns empty.""" + workflow = SimpleNamespace(_runner_context=SimpleNamespace()) + result = await _pending_request_events(cast(Any, workflow)) + assert result == {} + + async def test_get_pending_returns_non_dict(self): + """get_pending returning non-dict returns empty dict.""" + + async def get_pending(): + return ["not", "a", "dict"] + + workflow = SimpleNamespace(_runner_context=SimpleNamespace(get_pending_request_info_events=get_pending)) + result = await _pending_request_events(cast(Any, workflow)) + assert result == {} + + +class TestInterruptEntryForRequestEvent: + """Tests for _interrupt_entry_for_request_event helper.""" + + def test_request_id_none(self): + """request_id=None returns None.""" + event = SimpleNamespace(request_id=None) + assert _interrupt_entry_for_request_event(event) is None + + def test_dict_data_used_directly(self): + """Dict data is used as interrupt value.""" + event = SimpleNamespace(request_id="r1", data={"key": "val"}) + result = _interrupt_entry_for_request_event(event) + assert result == {"id": "r1", "value": {"key": "val"}} + + def test_non_dict_data_wrapped(self): + """Non-dict data is wrapped in {data: ...}.""" + event = SimpleNamespace(request_id="r1", data="text") + result = _interrupt_entry_for_request_event(event) + assert result == {"id": "r1", "value": {"data": "text"}} + + +class TestRequestPayloadFromRequestEvent: + """Tests for _request_payload_from_request_event helper.""" + + def test_falsy_request_id_returns_none(self): + """Empty string request_id returns None.""" + event = SimpleNamespace(request_id="", request_type=None, response_type=None, data=None) + assert _request_payload_from_request_event(event) is None + + +class TestCoerceJsonValue: + """Tests for _coerce_json_value helper.""" + + def test_empty_string(self): + """Empty string returns original value.""" + assert _coerce_json_value("") == "" + + def test_whitespace_string(self): + """Whitespace-only string returns original value.""" + assert _coerce_json_value(" ") == " " + + def test_valid_json_parsed(self): + """Valid JSON string is parsed.""" + assert _coerce_json_value('{"a": 1}') == {"a": 1} + + def test_invalid_json_returned_as_is(self): + """Invalid JSON string returned as-is.""" + assert _coerce_json_value("not json") == "not json" + + def test_non_string_returned_as_is(self): + """Non-string values returned as-is.""" + assert _coerce_json_value(42) == 42 + assert _coerce_json_value(None) is None + + +class TestCoerceContent: + """Tests for _coerce_content helper.""" + + def test_already_content(self): + """Content object returned as-is.""" + content = Content.from_text(text="hello") + assert _coerce_content(content) is content + + def test_non_dict_returns_none(self): + """Non-dict value (after JSON parse) returns None.""" + assert _coerce_content([1, 2, 3]) is None + assert _coerce_content(42) is None + + def test_auto_function_approval_response_type_attempted(self): + """Dict with approved+id+function_call triggers the auto-type detection path.""" + # The function injects type="function_approval_response" into a copy, + # but Content.from_dict may fail for complex nested types - returns None. + value = { + "approved": True, + "id": "a1", + "function_call": {"call_id": "c1", "name": "fn", "arguments": "{}"}, + } + # Exercises the auto-detection code path even though result is None + result = _coerce_content(value) + assert result is None # from_dict fails for this shape + + def test_valid_text_content_dict(self): + """Dict with type=text converts successfully.""" + result = _coerce_content({"type": "text", "text": "hello"}) + assert result is not None + assert result.type == "text" + assert result.text == "hello" + + +class TestCoerceMessageContent: + """Tests for _coerce_message_content helper.""" + + def test_string_content(self): + """String content creates text Content.""" + result = _coerce_message_content("hello") + assert result is not None + assert result.type == "text" + assert result.text == "hello" + + def test_already_content_object(self): + """Content object returned as-is.""" + content = Content.from_text(text="test") + assert _coerce_message_content(content) is content + + def test_none_input_returns_none(self): + """None input returns None.""" + assert _coerce_message_content(None) is None + + +class TestCoerceMessage: + """Tests for _coerce_message helper.""" + + def test_already_message(self): + """Message object returned as-is.""" + msg = Message(role="user", contents=[Content.from_text(text="hi")]) + assert _coerce_message(msg) is msg + + def test_non_dict_non_str_returns_none(self): + """Non-dict/str (e.g. int) returns None.""" + assert _coerce_message(123) is None + + def test_empty_contents(self): + """Dict with no contents key gets empty text content.""" + msg = _coerce_message({"role": "user"}) + assert msg is not None + assert len(msg.contents) == 1 + assert msg.contents[0].text == "" + + def test_dict_with_content_key_variant(self): + """'content' key maps to contents.""" + msg = _coerce_message({"role": "assistant", "content": "Done"}) + assert msg is not None + assert msg.role == "assistant" + assert len(msg.contents) == 1 + + +class TestCoerceResponseForRequest: + """Tests for _coerce_response_for_request helper.""" + + def test_response_type_none(self): + """None response_type returns candidate as-is.""" + event = SimpleNamespace(response_type=None) + assert _coerce_response_for_request(event, "hello") == "hello" + + def test_response_type_any(self): + """Any response_type returns candidate as-is.""" + event = SimpleNamespace(response_type=Any) + assert _coerce_response_for_request(event, {"a": 1}) == {"a": 1} + + def test_list_coercion_bare_list(self): + """list without type args passes through.""" + event = SimpleNamespace(response_type=list) + assert _coerce_response_for_request(event, [1, 2]) == [1, 2] + + def test_list_content_coercion(self): + """list[Content] coerces dicts to Content objects.""" + event = SimpleNamespace(response_type=list[Content]) + result = _coerce_response_for_request(event, [{"type": "text", "text": "hi"}]) + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], Content) + + def test_list_message_coercion(self): + """list[Message] coerces dicts to Message objects.""" + event = SimpleNamespace(response_type=list[Message]) + result = _coerce_response_for_request(event, [{"role": "user", "contents": [{"type": "text", "text": "hi"}]}]) + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], Message) + + def test_list_coercion_fails_returns_none(self): + """list coercion returns None when items can't be converted.""" + event = SimpleNamespace(response_type=list[Content]) + result = _coerce_response_for_request(event, [None]) + assert result is None + + def test_str_coercion_from_dict(self): + """str type coerces dict to JSON string.""" + event = SimpleNamespace(response_type=str) + result = _coerce_response_for_request(event, {"a": 1}) + assert isinstance(result, str) + assert '"a"' in result + + def test_unknown_type_mismatch(self): + """Custom class type returns None for non-instance.""" + + class Custom: + pass + + event = SimpleNamespace(response_type=Custom) + assert _coerce_response_for_request(event, "not_custom") is None + + def test_unknown_type_match(self): + """Custom class type returns object if isinstance matches.""" + + class Custom: + pass + + obj = Custom() + event = SimpleNamespace(response_type=Custom) + assert _coerce_response_for_request(event, obj) is obj + + +class TestSinglePendingResponseFromValue: + """Tests for _single_pending_response_from_value helper.""" + + def test_missing_request_id(self): + """Event with no request_id returns empty dict.""" + event = SimpleNamespace(response_type=str) + pending = {"key": event} + result = _single_pending_response_from_value(pending, "value") + assert result == {} + + def test_multiple_pending_returns_empty(self): + """Multiple pending events returns empty dict (ambiguous).""" + e1 = SimpleNamespace(request_id="r1", response_type=str) + e2 = SimpleNamespace(request_id="r2", response_type=str) + result = _single_pending_response_from_value({"r1": e1, "r2": e2}, "val") + assert result == {} + + +class TestCoerceResponsesForPendingRequests: + """Tests for _coerce_responses_for_pending_requests helper.""" + + def test_failed_coercion_skipped(self): + """Incompatible type causes response to be skipped.""" + event = SimpleNamespace(response_type=bool) + responses = {"r1": "not_a_bool"} + pending = {"r1": event} + result = _coerce_responses_for_pending_requests(responses, pending) + assert "r1" not in result + + def test_unknown_request_id_preserved(self): + """Responses for unknown request IDs are preserved as-is.""" + responses = {"unknown_id": "value"} + pending = {} + result = _coerce_responses_for_pending_requests(responses, pending) + assert result == {"unknown_id": "value"} + + def test_empty_responses(self): + """Empty responses dict returns responses unchanged.""" + result = _coerce_responses_for_pending_requests({}, {"r1": SimpleNamespace()}) + assert result == {} + + +class TestMessageRoleValue: + """Tests for _message_role_value helper.""" + + def test_string_role(self): + """String role returned directly.""" + msg = Message(role="user", contents=[]) + assert _message_role_value(msg) == "user" + + def test_enum_role(self): + """Enum-like role gets .value.""" + + class Role(Enum): + USER = "user" + + msg = SimpleNamespace(role=Role.USER) + assert _message_role_value(cast(Any, msg)) == "user" + + +class TestLatestUserText: + """Tests for _latest_user_text helper.""" + + def test_only_assistant_messages(self): + """Only assistant messages returns None.""" + messages = [Message(role="assistant", contents=[Content.from_text(text="hi")])] + assert _latest_user_text(messages) is None + + def test_user_with_non_text_content(self): + """User message with only non-text content returns None.""" + messages = [ + Message(role="user", contents=[Content.from_function_call(call_id="c1", name="fn", arguments="{}")]) + ] + assert _latest_user_text(messages) is None + + def test_user_with_empty_text(self): + """User message with empty/whitespace text returns None.""" + messages = [Message(role="user", contents=[Content.from_text(text=" ")])] + assert _latest_user_text(messages) is None + + +class TestLatestAssistantContents: + """Tests for _latest_assistant_contents helper.""" + + def test_no_assistant_messages(self): + """Only user messages returns None.""" + messages = [Message(role="user", contents=[Content.from_text(text="hi")])] + assert _latest_assistant_contents(messages) is None + + def test_assistant_with_empty_contents(self): + """Assistant message with empty contents returns None.""" + messages = [Message(role="assistant", contents=[])] + assert _latest_assistant_contents(messages) is None + + +class TestTextFromContents: + """Tests for _text_from_contents helper.""" + + def test_empty_text_skipped(self): + """Empty string text content is skipped.""" + contents = [Content.from_text(text="")] + assert _text_from_contents(contents) is None + + def test_non_text_content_skipped(self): + """Non-text content types are skipped.""" + contents = [Content.from_function_call(call_id="c1", name="fn", arguments="{}")] + assert _text_from_contents(contents) is None + + +class TestWorkflowInterruptEventValue: + """Tests for _workflow_interrupt_event_value helper.""" + + def test_none_data(self): + """None data returns None.""" + assert _workflow_interrupt_event_value({"data": None}) is None + + def test_string_data(self): + """String data returned directly.""" + assert _workflow_interrupt_event_value({"data": "text"}) == "text" + + def test_dict_data_serialized(self): + """Dict data is JSON-serialized.""" + result = _workflow_interrupt_event_value({"data": {"key": "val"}}) + assert json.loads(result) == {"key": "val"} + + +class TestWorkflowPayloadToContents: + """Tests for _workflow_payload_to_contents helper.""" + + def test_none_payload(self): + """None payload returns None.""" + assert _workflow_payload_to_contents(None) is None + + def test_non_assistant_message(self): + """User Message returns None.""" + msg = Message(role="user", contents=[Content.from_text(text="hi")]) + assert _workflow_payload_to_contents(msg) is None + + def test_agent_response_update_non_assistant(self): + """AgentResponseUpdate with user role returns None.""" + update = AgentResponseUpdate(contents=[Content.from_text(text="hi")], role="user") + assert _workflow_payload_to_contents(update) is None + + def test_agent_response_update_none_role(self): + """AgentResponseUpdate with None role returns None.""" + update = AgentResponseUpdate(contents=[Content.from_text(text="hi")], role=None) + assert _workflow_payload_to_contents(update) is None + + def test_list_with_none_item(self): + """List containing None causes None return.""" + result = _workflow_payload_to_contents([Content.from_text(text="hi"), None]) + assert result is None + + def test_empty_list(self): + """Empty list returns None.""" + assert _workflow_payload_to_contents([]) is None + + def test_string_payload(self): + """String payload creates text content.""" + result = _workflow_payload_to_contents("hello") + assert result is not None + assert len(result) == 1 + assert result[0].type == "text" + + def test_content_payload(self): + """Single Content returned as list.""" + content = Content.from_text(text="test") + result = _workflow_payload_to_contents(content) + assert result == [content] + + def test_unknown_type_returns_none(self): + """Unknown types return None.""" + assert _workflow_payload_to_contents(42) is None + + +class TestCustomEventValue: + """Tests for _custom_event_value helper.""" + + def test_event_with_data(self): + """Event with .data attribute returns data.""" + event = SimpleNamespace(type="custom", data={"progress": 50}) + assert _custom_event_value(event) == {"progress": 50} + + def test_event_without_data(self): + """Event without .data returns filtered custom fields.""" + event = SimpleNamespace(type="custom", data=None, custom_field="value") + result = _custom_event_value(event) + assert result == {"custom_field": "value"} + + def test_event_with_no_custom_fields(self): + """Event with only base fields returns None.""" + event = SimpleNamespace(type="custom", data=None) + result = _custom_event_value(event) + assert result is None + + +class TestDetailsMessage: + """Tests for _details_message helper.""" + + def test_none_details(self): + """None details returns default message.""" + assert _details_message(None) == "Workflow execution failed." + + def test_details_with_message(self): + """Details with .message attribute uses it.""" + details = SimpleNamespace(message="Custom error") + assert _details_message(details) == "Custom error" + + def test_details_with_empty_message(self): + """Details with empty .message falls back to str().""" + details = SimpleNamespace(message="") + result = _details_message(details) + assert "message=" in result or result == str(details) + + def test_details_without_message(self): + """Details without .message uses str().""" + assert _details_message("plain string") == "plain string" + + +class TestDetailsCode: + """Tests for _details_code helper.""" + + def test_none_details(self): + """None details returns None.""" + assert _details_code(None) is None + + def test_details_with_error_type(self): + """Details with .error_type returns it.""" + details = SimpleNamespace(error_type="ValueError") + assert _details_code(details) == "ValueError" + + def test_details_with_empty_error_type(self): + """Details with empty .error_type returns None.""" + details = SimpleNamespace(error_type="") + assert _details_code(details) is None + + def test_details_without_error_type(self): + """Details without .error_type returns None.""" + details = SimpleNamespace(message="err") + assert _details_code(details) is None + + +# ── Stream integration tests ── + + +async def test_workflow_run_available_interrupts_logged(): + """available_interrupts in input data should be logged without errors.""" + + @executor(id="noop") + async def noop(message: Any, ctx: WorkflowContext) -> None: + pass + + workflow = WorkflowBuilder(start_executor=noop).build() + input_data = { + "messages": [{"role": "user", "content": "go"}], + "available_interrupts": [{"id": "req_1", "type": "request_info"}], + } + + events = [event async for event in run_workflow_stream(input_data, workflow)] + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + assert "RUN_ERROR" not in event_types + + +async def test_workflow_run_failed_event(): + """Workflow 'failed' event should produce RUN_ERROR.""" + + class FailingWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace( + type="failed", details=SimpleNamespace(message="it broke", error_type="TestError") + ) + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, FailingWorkflow()) + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_ERROR" in event_types + error_event = next(e for e in events if e.type == "RUN_ERROR") + assert error_event.message == "it broke" + assert error_event.code == "TestError" + + +async def test_workflow_run_status_enum_state(): + """Status events with enum-like state should be handled.""" + + class WorkflowState(Enum): + IDLE = "idle" + + class StatusWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace(type="status", state=WorkflowState.IDLE) + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, StatusWorkflow()) + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + + +async def test_workflow_run_executor_invoked_drains_text(): + """executor_invoked should drain any open text message.""" + + class ExecutorWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace(type="output", data="Hello world") + yield SimpleNamespace(type="executor_invoked", executor_id="agent_1", data=None) + yield SimpleNamespace(type="executor_completed", executor_id="agent_1", data=None) + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, ExecutorWorkflow()) + ) + ] + + # Text should end before executor step starts + text_end_idx = next(i for i, e in enumerate(events) if e.type == "TEXT_MESSAGE_END") + step_start_idx = next(i for i, e in enumerate(events) if e.type == "STEP_STARTED") + assert text_end_idx < step_start_idx + + +async def test_workflow_run_executor_failed_event(): + """executor_failed event should emit activity snapshot with failed status.""" + + class ExecutorFailWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace( + type="executor_failed", + executor_id="agent_1", + details=SimpleNamespace(message="agent crashed"), + ) + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, ExecutorFailWorkflow()) + ) + ] + + activity = [e for e in events if e.type == "ACTIVITY_SNAPSHOT"] + assert len(activity) == 1 + assert activity[0].content["status"] == "failed" + assert activity[0].content["details"]["message"] == "agent crashed" + + +async def test_workflow_run_list_base_event_output(): + """Workflow yielding list of BaseEvent objects should emit each.""" + + class ListEventWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace( + type="output", + data=[ + StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot={"a": 1}), + StateSnapshotEvent(type=EventType.STATE_SNAPSHOT, snapshot={"b": 2}), + ], + ) + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, ListEventWorkflow()) + ) + ] + + snapshots = [e for e in events if e.type == "STATE_SNAPSHOT"] + assert len(snapshots) == 2 + assert snapshots[0].snapshot == {"a": 1} + assert snapshots[1].snapshot == {"b": 2} + + +async def test_workflow_run_late_run_started(): + """If no events emitted, RUN_STARTED still emitted at end.""" + + class EmptyWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + return + yield # pragma: no cover + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, EmptyWorkflow()) + ) + ] + + assert events[0].type == "RUN_STARTED" + assert events[-1].type == "RUN_FINISHED" + + +async def test_workflow_run_last_assistant_text_update(): + """Text outputs update last_assistant_text for dedup tracking.""" + + class DualTextWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace(type="output", data="First text") + yield SimpleNamespace(type="output", data="Second text") + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, DualTextWorkflow()) + ) + ] + + text_deltas = [e.delta for e in events if e.type == "TEXT_MESSAGE_CONTENT"] + assert "First text" in text_deltas + assert "Second text" in text_deltas + + +async def test_workflow_run_superstep_events(): + """superstep_started/completed emit Step events with iteration.""" + + class SuperstepWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace(type="superstep_started", iteration=1) + yield SimpleNamespace(type="superstep_completed", iteration=1) + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, SuperstepWorkflow()) + ) + ] + + step_started = [e for e in events if e.type == "STEP_STARTED"] + step_finished = [e for e in events if e.type == "STEP_FINISHED"] + assert len(step_started) == 1 + assert step_started[0].step_name == "superstep:1" + assert len(step_finished) == 1 + assert step_finished[0].step_name == "superstep:1" + + +async def test_workflow_run_non_terminal_status_emits_custom(): + """Non-terminal status events emit custom events.""" + + class StatusWorkflow: + def run(self, **kwargs: Any): + async def _stream(): + yield SimpleNamespace(type="started") + yield SimpleNamespace(type="status", state="running") + + return _stream() + + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "go"}]}, cast(Any, StatusWorkflow()) + ) + ] + + custom = [e for e in events if e.type == "CUSTOM" and e.name == "status"] + assert len(custom) == 1 + assert custom[0].value == {"state": "running"}