Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/packages/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ extend = "../../pyproject.toml"

[tool.pyright]
extends = "../../pyproject.toml"
include = ["tests/workflow"]

[tool.mypy]
plugins = ['pydantic.mypy']
Expand Down
119 changes: 95 additions & 24 deletions python/packages/core/tests/workflow/test_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@

import logging
from collections.abc import AsyncIterable, Awaitable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, overload

import pytest

from agent_framework import (
AgentExecutor,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Content,
Message,
ResponseStream,
WorkflowEvent,
WorkflowRunState,
)
from agent_framework._workflows._agent_executor import AgentExecutorResponse
Expand All @@ -32,26 +33,56 @@ def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.call_count = 0

@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
self.call_count += 1
if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(
contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")]
contents=[
Content.from_text(
text=f"Response #{self.call_count}: {self.name}"
)
]
)

return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])])
return AgentResponse(
messages=[
Message("assistant", [f"Response #{self.call_count}: {self.name}"])
]
)

return _run()

Expand All @@ -63,13 +94,36 @@ def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.result_hook_called = False

@overload
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
if stream:

async def _stream() -> AsyncIterable[AgentResponseUpdate]:
Expand All @@ -78,21 +132,25 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]:
role="assistant",
)

async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse:
async def _mark_result_hook_called(
response: AgentResponse,
) -> AgentResponse:
self.result_hook_called = True
return response

return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
_mark_result_hook_called
)
return ResponseStream(
_stream(), finalizer=AgentResponse.from_updates
).with_result_hook(_mark_result_hook_called)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["hook test"])])

return _run()


async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> (
None
):
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
executor = AgentExecutor(agent, id="hook_exec")
Expand Down Expand Up @@ -159,7 +217,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:

executor_state = executor_states[executor.id] # type: ignore[index]
assert "cache" in executor_state, "Checkpoint should store executor cache state"
assert "agent_session" in executor_state, "Checkpoint should store executor session state"
assert "agent_session" in executor_state, (
"Checkpoint should store executor session state"
)

# Verify session state structure
session_state = executor_state["agent_session"] # type: ignore[index]
Expand All @@ -180,11 +240,15 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert restored_agent.call_count == 0

# Build new workflow with the restored executor
wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build()
wf_resume = SequentialBuilder(
participants=[restored_executor], checkpoint_storage=storage
).build()

# Resume from checkpoint
resumed_output: AgentExecutorResponse | None = None
async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True):
async for ev in wf_resume.run(
checkpoint_id=restore_checkpoint.checkpoint_id, stream=True
):
if ev.type == "output":
resumed_output = ev.data # type: ignore[assignment]
if ev.type == "status" and ev.state in (
Expand Down Expand Up @@ -278,7 +342,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
workflow = SequentialBuilder(participants=[executor]).build()

# stream=True at workflow level triggers streaming mode (returns async iterable)
events = []
events: list[WorkflowEvent] = []
async for event in workflow.run("hello", stream=True):
events.append(event)
assert len(events) > 0
Expand All @@ -290,10 +354,13 @@ async def test_prepare_agent_run_args_strips_reserved_kwargs(
reserved_kwarg: str, caplog: "LogCaptureFixture"
) -> None:
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"}
raw: dict[str, Any] = {
reserved_kwarg: "should-be-stripped",
"custom_key": "keep-me",
}

with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]

assert reserved_kwarg not in run_kwargs
assert "custom_key" in run_kwargs
Expand All @@ -304,8 +371,8 @@ async def test_prepare_agent_run_args_strips_reserved_kwargs(

async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
"""Non-reserved workflow kwargs should pass through unchanged."""
raw = {"custom_param": "value", "another": 42}
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
raw: dict[str, Any] = {"custom_param": "value", "another": 42}
run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert run_kwargs["custom_param"] == "value"
assert run_kwargs["another"] == 42

Expand All @@ -314,10 +381,10 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
caplog: "LogCaptureFixture",
) -> None:
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
raw = {"session": "x", "stream": True, "messages": [], "custom": 1}
raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1}

with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]

assert "session" not in run_kwargs
assert "stream" not in run_kwargs
Expand All @@ -326,7 +393,11 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
assert options is not None
assert options["additional_function_arguments"]["custom"] == 1

warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
warned_keys = {
r.message.split("'")[1]
for r in caplog.records
if "reserved" in r.message.lower()
}
assert warned_keys == {"session", "stream", "messages"}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Tests for AgentExecutor handling of tool calls and results in streaming mode."""

from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from typing import Any
from typing import Any, Literal, overload

from typing_extensions import Never

Expand All @@ -13,6 +13,7 @@
AgentExecutorResponse,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
ChatResponse,
Expand All @@ -37,18 +38,38 @@ class _ToolCallingAgent(BaseAgent):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...

@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if stream:
return ResponseStream(self._run_stream_impl(), finalizer=AgentResponse.from_updates)

async def _run() -> AgentResponse:
async def _run() -> AgentResponse[Any]:
return AgentResponse(messages=[Message("assistant", ["done"])])

return _run()
Expand Down Expand Up @@ -111,6 +132,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
# First event: text update
assert events[0].data is not None
assert events[0].data.contents[0].type == "text"
assert events[0].data.contents[0].text is not None
assert "Let me search" in events[0].data.contents[0].text

# Second event: function call
Expand All @@ -129,6 +151,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
# Fourth event: final text
assert events[3].data is not None
assert events[3].data.contents[0].type == "text"
assert events[3].data.contents[0].text is not None
assert "sunny" in events[3].data.contents[0].text


Expand Down
45 changes: 14 additions & 31 deletions python/packages/core/tests/workflow/test_agent_utils.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,33 @@
# Copyright (c) Microsoft. All rights reserved.

from collections.abc import AsyncIterable
from typing import Any
from collections.abc import Awaitable
from typing import Any, Literal, overload

from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Message
from agent_framework import AgentResponse, AgentResponseUpdate, AgentRunInputs, AgentSession, ResponseStream
from agent_framework._workflows._agent_utils import resolve_agent_id


class MockAgent:
"""Mock agent for testing agent utilities."""

def __init__(self, agent_id: str, name: str | None = None) -> None:
self._id = agent_id
self._name = name
self.id: str = agent_id
self.name: str | None = name
self.description: str | None = None

@property
def id(self) -> str:
return self._id

@property
def name(self) -> str | None:
return self._name

@property
def display_name(self) -> str:
"""Returns the display name of the agent."""
...

@property
def description(self) -> str | None:
"""Returns the description of the agent."""
...

def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

def create_session(self, **kwargs: Any) -> AgentSession:
"""Creates a new conversation session for the agent."""
...

def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()


def test_resolve_agent_id_with_name() -> None:
"""Test that resolve_agent_id returns name when agent has a name."""
Expand Down
Loading