Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/packages/ag-ui/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ packages = ["agent_framework_ag_ui", "agent_framework_ag_ui_examples"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests/ag_ui"]
pythonpath = ["."]
pythonpath = [".", "tests/ag_ui"]
markers = [
"integration: marks tests as integration tests that require external services",
]
Expand Down
88 changes: 88 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sys
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Mapping, MutableSequence, Sequence
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Generic, Literal, cast, overload

Expand Down Expand Up @@ -36,6 +37,13 @@
ResponseFn = Callable[..., Awaitable[ChatResponse]]


def pytest_configure() -> None:
"""Ensure this test directory is on sys.path so helper modules can be imported by name."""
test_dir = str(Path(__file__).resolve().parent)
if test_dir not in sys.path:
sys.path.insert(0, test_dir)


class StreamingChatClientStub(
ChatMiddlewareLayer[OptionsCoT],
FunctionInvocationLayer[OptionsCoT],
Expand Down Expand Up @@ -241,3 +249,83 @@ def stream_from_updates_fixture() -> Callable[[list[ChatResponseUpdate]], Stream
def stub_agent() -> type[SupportsAgentRun]:
"""Return the StubAgent class for creating test instances."""
return StubAgent # type: ignore[return-value]


# ── Fixtures for golden / integration tests ──


@pytest.fixture
def collect_events() -> Callable[..., Any]:
"""Return an async helper that collects all events from an async generator."""

async def _collect(async_gen: AsyncIterable[Any]) -> list[Any]:
return [event async for event in async_gen]

return _collect


@pytest.fixture
def make_agent_wrapper() -> Callable[..., Any]:
"""Factory that builds an AgentFrameworkAgent from a stream function.

Usage::

agent = make_agent_wrapper(
stream_fn=stream_from_updates(updates),
state_schema=...,
)
events = [e async for e in agent.run(payload)]
"""
from agent_framework_ag_ui import AgentFrameworkAgent

def _factory(
stream_fn: StreamFn,
*,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
) -> Any:
client = StreamingChatClientStub(stream_fn)
stub = StubAgent(client=client)
return AgentFrameworkAgent(
agent=stub,
state_schema=state_schema,
predict_state_config=predict_state_config,
require_confirmation=require_confirmation,
)

return _factory


@pytest.fixture
def make_app() -> Callable[..., Any]:
"""Factory that builds a FastAPI app with an AG-UI endpoint.

Usage::

app = make_app(agent_or_wrapper, path="/test")
"""
from fastapi import FastAPI

from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint

def _factory(
agent: Any,
*,
path: str = "/",
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
default_state: dict[str, Any] | None = None,
) -> FastAPI:
app = FastAPI()
add_agent_framework_fastapi_endpoint(
app,
agent,
path=path,
state_schema=state_schema,
predict_state_config=predict_state_config,
default_state=default_state,
)
return app

return _factory
175 changes: 175 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/event_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright (c) Microsoft. All rights reserved.

"""EventStream assertion helper for AG-UI regression tests."""

from __future__ import annotations

from typing import Any


class EventStream:
"""Wraps a list of AG-UI events with structured assertion methods.

Usage:
events = [event async for event in agent.run(payload)]
stream = EventStream(events)
stream.assert_bookends()
stream.assert_text_messages_balanced()
"""

def __init__(self, events: list[Any]) -> None:
self.events = events

def __len__(self) -> int:
return len(self.events)

def __iter__(self):
return iter(self.events)

def types(self) -> list[str]:
"""Return ordered list of event type strings."""
return [self._type_str(e) for e in self.events]

def get(self, event_type: str) -> list[Any]:
"""Filter events matching the given type string."""
return [e for e in self.events if self._type_str(e) == event_type]

def first(self, event_type: str) -> Any:
"""Return the first event matching the given type, or raise."""
matches = self.get(event_type)
if not matches:
raise ValueError(f"No event of type {event_type!r} found. Available: {self.types()}")
return matches[0]

def last(self, event_type: str) -> Any:
"""Return the last event matching the given type, or raise."""
matches = self.get(event_type)
if not matches:
raise ValueError(f"No event of type {event_type!r} found. Available: {self.types()}")
return matches[-1]

def snapshot(self) -> dict[str, Any]:
"""Return the latest StateSnapshotEvent snapshot dict."""
return self.last("STATE_SNAPSHOT").snapshot

def messages_snapshot(self) -> list[Any]:
"""Return the latest MessagesSnapshotEvent messages list."""
return self.last("MESSAGES_SNAPSHOT").messages

# ── Structural assertions ──

def assert_bookends(self) -> None:
"""Assert first event is RUN_STARTED and last is RUN_FINISHED."""
types = self.types()
assert types, "Event stream is empty"
assert types[0] == "RUN_STARTED", f"Expected RUN_STARTED first, got {types[0]}"
assert types[-1] == "RUN_FINISHED", f"Expected RUN_FINISHED last, got {types[-1]}"

def assert_has_run_lifecycle(self) -> None:
"""Assert RUN_STARTED is first and RUN_FINISHED exists (may not be last).

Use this instead of assert_bookends() for workflow resume streams where
_drain_open_message() can emit TEXT_MESSAGE_END after RUN_FINISHED.
"""
types = self.types()
assert types, "Event stream is empty"
assert types[0] == "RUN_STARTED", f"Expected RUN_STARTED first, got {types[0]}"
assert "RUN_FINISHED" in types, f"Expected RUN_FINISHED in stream. Types: {types}"

def assert_strict_types(self, expected: list[str]) -> None:
"""Assert exact type sequence match."""
actual = self.types()
assert actual == expected, f"Event type mismatch.\nExpected: {expected}\nActual: {actual}"

def assert_ordered_types(self, expected: list[str]) -> None:
"""Assert expected types appear as a subsequence (in order, not necessarily contiguous)."""
actual = self.types()
actual_idx = 0
for expected_type in expected:
found = False
while actual_idx < len(actual):
if actual[actual_idx] == expected_type:
actual_idx += 1
found = True
break
actual_idx += 1
if not found:
raise AssertionError(
f"Expected subsequence type {expected_type!r} not found after index {actual_idx}.\n"
f"Expected subsequence: {expected}\n"
f"Actual types: {actual}"
)

def assert_text_messages_balanced(self) -> None:
"""Assert every TEXT_MESSAGE_START has a matching TEXT_MESSAGE_END with the same message_id."""
starts: dict[str, int] = {}
ends: set[str] = set()
for i, event in enumerate(self.events):
t = self._type_str(event)
if t == "TEXT_MESSAGE_START":
mid = event.message_id
assert mid not in starts, f"Duplicate TEXT_MESSAGE_START for message_id={mid}"
starts[mid] = i
elif t == "TEXT_MESSAGE_END":
mid = event.message_id
assert mid in starts, f"TEXT_MESSAGE_END for unknown message_id={mid}"
assert mid not in ends, f"Duplicate TEXT_MESSAGE_END for message_id={mid}"
ends.add(mid)

unclosed = set(starts.keys()) - ends
assert not unclosed, f"Unclosed text messages: {unclosed}"

def assert_tool_calls_balanced(self) -> None:
"""Assert every TOOL_CALL_START has a matching TOOL_CALL_END with the same tool_call_id."""
starts: dict[str, int] = {}
ends: set[str] = set()
for i, event in enumerate(self.events):
t = self._type_str(event)
if t == "TOOL_CALL_START":
tid = event.tool_call_id
assert tid not in starts, f"Duplicate TOOL_CALL_START for tool_call_id={tid}"
starts[tid] = i
elif t == "TOOL_CALL_END":
tid = event.tool_call_id
assert tid in starts, f"TOOL_CALL_END for unknown tool_call_id={tid}"
assert tid not in ends, f"Duplicate TOOL_CALL_END for tool_call_id={tid}"
ends.add(tid)

unclosed = set(starts.keys()) - ends
assert not unclosed, f"Unclosed tool calls: {unclosed}"

def assert_no_run_error(self) -> None:
"""Assert no RUN_ERROR events exist."""
errors = self.get("RUN_ERROR")
if errors:
messages = [getattr(e, "message", str(e)) for e in errors]
raise AssertionError(f"Found {len(errors)} RUN_ERROR event(s): {messages}")

def assert_has_type(self, event_type: str) -> None:
"""Assert at least one event of the given type exists."""
assert event_type in self.types(), f"Expected {event_type!r} in stream. Available: {self.types()}"

def assert_message_ids_consistent(self) -> None:
"""Assert TEXT_MESSAGE_CONTENT events reference valid, open message_ids."""
open_messages: set[str] = set()
for event in self.events:
t = self._type_str(event)
if t == "TEXT_MESSAGE_START":
open_messages.add(event.message_id)
elif t == "TEXT_MESSAGE_END":
open_messages.discard(event.message_id)
elif t == "TEXT_MESSAGE_CONTENT":
mid = event.message_id
assert mid in open_messages, f"TEXT_MESSAGE_CONTENT references message_id={mid} which is not open"

# ── Internal ──

@staticmethod
def _type_str(event: Any) -> str:
"""Extract event type as a plain string."""
t = getattr(event, "type", None)
if t is None:
return type(event).__name__
if isinstance(t, str):
return t
return getattr(t, "value", str(t))
1 change: 1 addition & 0 deletions python/packages/ag-ui/tests/ag_ui/golden/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
13 changes: 13 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/golden/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.

"""Conftest for golden tests — ensures parent test dir is importable."""

import sys
from pathlib import Path


def pytest_configure() -> None:
"""Ensure parent test directory is on sys.path for helper module imports."""
parent_test_dir = str(Path(__file__).resolve().parent.parent)
if parent_test_dir not in sys.path:
sys.path.insert(0, parent_test_dir)
Loading