diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py index 17b927326b..b887d86df3 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_magentic.py @@ -14,6 +14,7 @@ from agent_framework import ( AgentResponse, + AgentSession, Message, SupportsAgentRun, ) @@ -559,6 +560,7 @@ def __init__( ) self._agent: SupportsAgentRun = agent + self._session: AgentSession = self._agent.create_session() self.task_ledger: _MagenticTaskLedger | None = task_ledger # Prompts may be overridden if needed @@ -587,7 +589,7 @@ async def _complete( The agent's run method is called which applies the agent's configured options (temperature, seed, instructions, etc.). """ - response: AgentResponse = await self._agent.run(messages) + response: AgentResponse = await self._agent.run(messages, session=self._session) if not response.messages: raise RuntimeError("Agent returned no messages in response.") if len(response.messages) > 1: @@ -730,6 +732,7 @@ def on_checkpoint_save(self) -> dict[str, Any]: state: dict[str, Any] = {} if self.task_ledger is not None: state["task_ledger"] = self.task_ledger.to_dict() + state["agent_session"] = self._session.to_dict() return state @override @@ -740,6 +743,12 @@ def on_checkpoint_restore(self, state: dict[str, Any]) -> None: self.task_ledger = _MagenticTaskLedger.from_dict(ledger) except Exception: # pragma: no cover - defensive logger.warning("Failed to restore manager task ledger from checkpoint state") + session_payload = state.get("agent_session") + if session_payload is not None: + try: + self._session = AgentSession.from_dict(session_payload) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore manager agent session from checkpoint state") # endregion Magentic Manager diff --git a/python/packages/orchestrations/tests/test_magentic.py b/python/packages/orchestrations/tests/test_magentic.py index e1d0ef8c32..1857a16ee4 100644 --- a/python/packages/orchestrations/tests/test_magentic.py +++ b/python/packages/orchestrations/tests/test_magentic.py @@ -1074,4 +1074,71 @@ def agent_factory() -> SupportsAgentRun: assert manager.final_answer_prompt == custom_final_prompt +async def test_standard_manager_propagates_session_to_agent(): + """Verify StandardMagenticManager passes a consistent session to the underlying agent. + + Regression test for #4371: context providers (e.g. RedisHistoryProvider) configured on + the manager agent silently failed because no session was propagated. + """ + captured_sessions: list[AgentSession | None] = [] + + class SessionCapturingAgent(BaseAgent): + """Agent that records the session passed to each run() call.""" + + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + *, + stream: bool = False, + session: Any = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse] | AsyncIterable[AgentResponseUpdate]: + captured_sessions.append(session) + + async def _run() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", ["ok"])]) + + return _run() + + agent = SessionCapturingAgent() + mgr = StandardMagenticManager(agent=agent) + ctx = MagenticContext(task="task", participant_descriptions={"a": "desc"}) + + await mgr.plan(ctx.clone()) + + # plan() calls _complete twice (facts + plan), both should receive the same session + assert len(captured_sessions) == 2 + assert all(s is not None for s in captured_sessions), "session must be passed to agent.run()" + assert captured_sessions[0] is captured_sessions[1], "same session instance must be reused across calls" + assert captured_sessions[0] is mgr._session + + +def test_standard_manager_checkpoint_preserves_session(): + """Verify that checkpoint save/restore preserves the manager's session identity.""" + agent = StubManagerAgent() + mgr = StandardMagenticManager(agent=agent) + original_session_id = mgr._session.session_id + + state = mgr.on_checkpoint_save() + assert "agent_session" in state + + # Restore into a fresh manager and verify session_id is preserved + mgr2 = StandardMagenticManager(agent=agent) + assert mgr2._session.session_id != original_session_id + mgr2.on_checkpoint_restore(state) + assert mgr2._session.session_id == original_session_id + + +def test_standard_manager_checkpoint_restore_empty_state(): + """Verify that restoring from a state without agent_session leaves the session intact.""" + agent = StubManagerAgent() + mgr = StandardMagenticManager(agent=agent) + original_session = mgr._session + original_session_id = original_session.session_id + + mgr.on_checkpoint_restore({}) + assert mgr._session is original_session + assert mgr._session.session_id == original_session_id + + # endregion