Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from agent_framework import (
AgentResponse,
AgentSession,
Message,
SupportsAgentRun,
)
Expand Down Expand Up @@ -559,6 +560,7 @@ def __init__(
)

self._agent: SupportsAgentRun = agent
self._session: AgentSession = self._agent.create_session()
self.task_ledger: _MagenticTaskLedger | None = task_ledger

# Prompts may be overridden if needed
Expand Down Expand Up @@ -587,7 +589,7 @@ async def _complete(
The agent's run method is called which applies the agent's configured options
(temperature, seed, instructions, etc.).
"""
response: AgentResponse = await self._agent.run(messages)
response: AgentResponse = await self._agent.run(messages, session=self._session)
if not response.messages:
raise RuntimeError("Agent returned no messages in response.")
if len(response.messages) > 1:
Expand Down Expand Up @@ -730,6 +732,7 @@ def on_checkpoint_save(self) -> dict[str, Any]:
state: dict[str, Any] = {}
if self.task_ledger is not None:
state["task_ledger"] = self.task_ledger.to_dict()
state["agent_session"] = self._session.to_dict()
return state

@override
Expand All @@ -740,6 +743,12 @@ def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
self.task_ledger = _MagenticTaskLedger.from_dict(ledger)
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager task ledger from checkpoint state")
session_payload = state.get("agent_session")
if session_payload is not None:
try:
self._session = AgentSession.from_dict(session_payload)
except Exception: # pragma: no cover - defensive
logger.warning("Failed to restore manager agent session from checkpoint state")


# endregion Magentic Manager
Expand Down
67 changes: 67 additions & 0 deletions python/packages/orchestrations/tests/test_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,4 +1074,71 @@ def agent_factory() -> SupportsAgentRun:
assert manager.final_answer_prompt == custom_final_prompt


async def test_standard_manager_propagates_session_to_agent():
"""Verify StandardMagenticManager passes a consistent session to the underlying agent.

Regression test for #4371: context providers (e.g. RedisHistoryProvider) configured on
the manager agent silently failed because no session was propagated.
"""
captured_sessions: list[AgentSession | None] = []

class SessionCapturingAgent(BaseAgent):
"""Agent that records the session passed to each run() call."""

def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: Any = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]:
captured_sessions.append(session)

async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["ok"])])

return _run()

agent = SessionCapturingAgent()
mgr = StandardMagenticManager(agent=agent)
ctx = MagenticContext(task="task", participant_descriptions={"a": "desc"})

await mgr.plan(ctx.clone())

# plan() calls _complete twice (facts + plan), both should receive the same session
assert len(captured_sessions) == 2
assert all(s is not None for s in captured_sessions), "session must be passed to agent.run()"
assert captured_sessions[0] is captured_sessions[1], "same session instance must be reused across calls"
assert captured_sessions[0] is mgr._session


def test_standard_manager_checkpoint_preserves_session():
"""Verify that checkpoint save/restore preserves the manager's session identity."""
agent = StubManagerAgent()
mgr = StandardMagenticManager(agent=agent)
original_session_id = mgr._session.session_id

state = mgr.on_checkpoint_save()
assert "agent_session" in state

# Restore into a fresh manager and verify session_id is preserved
mgr2 = StandardMagenticManager(agent=agent)
assert mgr2._session.session_id != original_session_id
mgr2.on_checkpoint_restore(state)
assert mgr2._session.session_id == original_session_id


def test_standard_manager_checkpoint_restore_empty_state():
"""Verify that restoring from a state without agent_session leaves the session intact."""
agent = StubManagerAgent()
mgr = StandardMagenticManager(agent=agent)
original_session = mgr._session
original_session_id = original_session.session_id

mgr.on_checkpoint_restore({})
assert mgr._session is original_session
assert mgr._session.session_id == original_session_id


# endregion