Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/rai_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "rai_core"
version = "2.6.0"
version = "2.7.0"
description = "Core functionality for RAI framework"
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
readme = "README.md"
Expand Down
64 changes: 51 additions & 13 deletions src/rai_core/rai/agents/langchain/core/megamind.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,27 +209,64 @@ def get_initial_megamind_state(task: str):
)


@dataclass
class PlanPrompts:
    """Configurable prompts for the planning step.

    Each field is one fragment of the prompt assembled by the planning
    node; ``objective_template`` must keep the ``{original_task}``
    placeholder so it can be formatted with the user's objective.
    """

    # Template for stating the objective; formatted with original_task.
    objective_template: str = "You are given objective to complete: {original_task}"
    # Header preceding the numbered list of completed steps.
    steps_done_header: str = "Steps that were already done successfully:\n"
    # Used when at least one step has already been executed.
    next_step_prompt: str = "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent."
    # Used when no step has been executed yet.
    first_step_prompt: str = (
        "\nCome up with the first step and delegate it to selected agent."
    )
    # Always appended last: tells the agent how to finish the objective.
    completion_prompt: str = (
        "\n\nWhen you decide that the objective is completed return response to user."
    )

    @classmethod
    def default(cls):
        """Return prompts with all default values."""
        return cls()

    @classmethod
    def custom(cls, **kwargs):
        """Create prompts overriding selected fields.

        Raises:
            TypeError: if a keyword does not name a PlanPrompts field.
                (The previous implementation silently ignored unknown
                names, hiding typos, and its ``hasattr`` guard even
                allowed shadowing the classmethods with instance
                attributes.)
        """
        # Delegating to the generated dataclass constructor applies
        # defaults for omitted fields and rejects unknown keywords.
        return cls(**kwargs)


def plan_step(
megamind_agent: BaseChatModel,
state: MegamindState,
prompts: Optional[PlanPrompts] = None,
context_providers: Optional[List[ContextProvider]] = None,
) -> MegamindState:
"""Initial planning step."""
if prompts is None:
prompts = PlanPrompts.default()

if "original_task" not in state:
state["original_task"] = state["messages"][0].content[0]["text"]
if "steps_done" not in state:
state["steps_done"] = []
if "step" not in state:
state["step"] = None

megamind_prompt = f"You are given objective to complete: {state['original_task']}"
megamind_prompt = prompts.objective_template.format(
original_task=state["original_task"]
)
if context_providers:
for provider in context_providers:
megamind_prompt += provider.get_context()
megamind_prompt += "\n"

# Add completed steps if any
if state["steps_done"]:
megamind_prompt += "\n\n"
megamind_prompt += "Steps that were already done successfully:\n"
megamind_prompt += prompts.steps_done_header
steps_done = "\n".join(
[f"{i + 1}. {step}" for i, step in enumerate(state["steps_done"])]
)
Expand All @@ -239,22 +276,17 @@ def plan_step(
if state["step"]:
if not state["step_success"]:
raise ValueError("Step success should be specified at this point")

megamind_prompt += "\nBased on that outcome and past steps come up with the next step and delegate it to selected agent."
megamind_prompt += prompts.next_step_prompt

else:
megamind_prompt += "\n"
megamind_prompt += (
"Come up with the fist step and delegate it to selected agent."
)
megamind_prompt += prompts.first_step_prompt

megamind_prompt += prompts.completion_prompt

megamind_prompt += "\n\n"
megamind_prompt += (
"When you decide that the objective is completed return response to user."
)
messages = [
HumanMultimodalMessage(content=megamind_prompt),
]

# NOTE (jmatejcz) the response of megamind isnt appended to messages
# as Command from handoff instantly transitions to next node
megamind_agent.invoke({"messages": messages})
Expand All @@ -266,6 +298,7 @@ def create_megamind(
executors: List[Executor],
megamind_system_prompt: Optional[str] = None,
task_planning_prompt: Optional[str] = None,
plan_prompts: Optional[PlanPrompts] = None,
context_providers: List[ContextProvider] = [],
) -> CompiledStateGraph:
"""Create a megamind langchain agent
Expand Down Expand Up @@ -325,7 +358,12 @@ def create_megamind(

graph = StateGraph(MegamindState).add_node(
"megamind",
partial(plan_step, megamind_agent, context_providers=context_providers),
partial(
plan_step,
megamind_agent,
context_providers=context_providers,
prompts=plan_prompts,
),
)
for agent_name, agent in executor_agents.items():
graph.add_node(agent_name, agent)
Expand Down
228 changes: 228 additions & 0 deletions tests/agents/langchain/test_megamind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import pytest
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.tools import tool
from rai.agents.langchain.core.megamind import (
ContextProvider,
Executor,
MegamindState,
PlanPrompts,
StepSuccess,
analyzer_node,
create_megamind,
create_react_structured_agent,
get_initial_megamind_state,
llm_node,
plan_step,
should_continue_or_structure,
)
from rai.messages import HumanMultimodalMessage


@pytest.fixture
def mock_llm():
    """Mocked chat model; bind_tools() hands back the same mock so chained calls work."""
    fake_llm = MagicMock()
    fake_llm.bind_tools.return_value = fake_llm
    return fake_llm


@pytest.fixture
def mock_tool():
    """A minimal langchain tool echoing its query argument."""

    @tool
    def sample_tool(query: str):
        # NOTE: @tool exposes this docstring as the tool's description,
        # so the typo "Smple" was a user-visible defect.
        """Sample tool"""
        return f"Tool called with {query}"

    return sample_tool


@pytest.fixture(scope="function")
def default_state():
    """A fully-populated MegamindState used as the baseline by most tests."""
    return {
        "original_task": "task",
        "steps_done": [],
        "step": "current step",
        "step_success": StepSuccess(success=False, explanation=""),
        "step_messages": [HumanMessage(content="hello")],
        "messages": [],
    }


# Minimal ContextProvider stub; its fixed output lets tests verify that
# plan_step appends provider context to the megamind prompt.
class MockProvider(ContextProvider):
    def get_context(self) -> str:
        """Return a fixed context string."""
        return "Extra context"


def test_llm_node(mock_llm, default_state):
    """llm_node sends [system prompt, current step] and appends the reply."""
    expected_reply = AIMessage(content="response")
    mock_llm.invoke.return_value = expected_reply

    updated = llm_node(mock_llm, "system prompt", default_state)

    # The LLM response is appended after the existing human message.
    assert len(updated["step_messages"]) == 2
    assert updated["step_messages"][-1] == expected_reply

    # Verify the exact messages handed to the LLM.
    (sent_messages,), _ = mock_llm.invoke.call_args
    system_msg, step_msg = sent_messages[0], sent_messages[1]
    assert isinstance(system_msg, SystemMessage)
    assert system_msg.content == "system prompt"
    assert isinstance(step_msg, HumanMessage)
    assert step_msg.content == "current step"


def test_analyzer_node(mock_llm, default_state):
    """analyzer_node stores the structured verdict and records the step as done."""
    verdict = StepSuccess(success=True, explanation="Great success")

    # with_structured_output() must yield an analyzer returning the verdict.
    structured = MagicMock()
    structured.invoke.return_value = verdict
    mock_llm.with_structured_output.return_value = structured

    updated = analyzer_node(mock_llm, "plan prompt", default_state)

    assert updated["step_success"] == verdict
    assert len(updated["steps_done"]) == 1
    assert "Great success" in updated["steps_done"][0]


def test_should_continue_or_structure(default_state):
    """Routing: no tool calls -> structured output; tool calls -> tools node."""
    # Last step message carries no tool calls.
    assert should_continue_or_structure(default_state) == "structured_output"

    # Last step message carries a tool call.
    # (Removed a stray AIMessage(...) expression that was built and discarded.)
    default_state["step_messages"] = [
        AIMessage(content="text", tool_calls=[{"name": "tool", "args": {}, "id": "1"}])
    ]
    assert should_continue_or_structure(default_state) == "tools"


# def test_create_handoff_tool():
# handoff = create_handoff_tool("agent_x", "desc")
# assert handoff.name == "transfer_to_agent_x"
# assert handoff.description == "desc"

# # Verify the tool logic (returns a Command)
# result = handoff.invoke({"task_instruction": "do subtask"})
# assert isinstance(result, Command)
# assert result.goto == "agent_x"
# assert result.update["step"] == "do subtask"
# assert result.graph == Command.PARENT


def test_plan_step(mock_llm, default_state):
    """plan_step builds a prompt containing the objective and completed steps."""
    default_state.update(
        {
            "messages": [HumanMultimodalMessage(content="task")],
            "steps_done": ["step 1 done"],
            "step": "step 1",
            "step_success": StepSuccess(success=True, explanation="ok"),
            "step_messages": [],
        }
    )

    mock_llm.invoke.return_value = AIMessage(content="ignored")

    plan_step(mock_llm, default_state)

    assert mock_llm.invoke.called
    # Inspect the prompt actually sent to the megamind agent.
    call_args = mock_llm.invoke.call_args
    messages = call_args[0][0]["messages"]
    prompt_text = messages[0].content[0]["text"]
    # The prompt must include the objective and the completed-step history.
    # (The previous assertion, `"StepSuccess" not in str(type(prompt_text))`,
    # was vacuous: str(type(...)) is e.g. "<class 'list'>" and could never
    # contain "StepSuccess".)
    assert "task" in prompt_text
    assert "step 1 done" in prompt_text


def test_plan_prompts():
    """Defaults mention the objective; custom() overrides only the given field."""
    base = PlanPrompts.default()
    assert "objective" in base.objective_template

    overridden = PlanPrompts.custom(objective_template="Custom {original_task}")
    assert overridden.objective_template == "Custom {original_task}"
    assert overridden.steps_done_header == base.steps_done_header


def test_plan_step_edge_cases(mock_llm, default_state):
    """Covers the first-step path, a missing verdict, and key initialization."""
    # First-step path (step is None) with an extra context provider attached.
    default_state["messages"] = [HumanMultimodalMessage(content="task")]
    default_state["step"] = None
    mock_llm.invoke.return_value = AIMessage(content="ok")

    plan_step(
        mock_llm,
        default_state,
        prompts=None,
        context_providers=[MockProvider()],
    )

    # The provider's output must appear in the generated prompt.
    (payload,), _ = mock_llm.invoke.call_args
    first_prompt = payload["messages"][0].content[0]["text"]
    assert "Extra context" in first_prompt

    # A defined step without a step_success verdict is an error.
    default_state["step"] = "some step"
    default_state["step_success"] = None
    with pytest.raises(ValueError, match="Step success should be specified"):
        plan_step(mock_llm, default_state)

    # Missing state keys are initialized from the first message.
    fresh_state: MegamindState = {
        "messages": [HumanMultimodalMessage(content="Start")]
    }
    plan_step(mock_llm, fresh_state)
    assert fresh_state["original_task"] == "Start"
    assert fresh_state["steps_done"] == []
    assert fresh_state["step"] is None


def test_get_initial_megamind_state():
    """The initial state wraps the task in a single multimodal message."""
    initial = get_initial_megamind_state("My Task")

    assert initial["original_task"] == "My Task"
    assert initial["steps_done"] == []
    assert len(initial["messages"]) == 1
    assert initial["messages"][0].content == [{"text": "My Task", "type": "text"}]


def test_create_react_structured_agent_no_tools(mock_llm):
    """An agent graph can be built even when no tools are supplied."""
    assert create_react_structured_agent(mock_llm, tools=[]) is not None


def test_llm_node_error(mock_llm, default_state):
    """llm_node refuses to run when no current step is set."""
    default_state["step"] = None
    with pytest.raises(ValueError, match="Step should be defined"):
        llm_node(mock_llm, "sys", default_state)


def test_create_megamind(mock_llm, mock_tool):
    """A megamind graph compiles from a single executor definition."""
    specialist = Executor(
        name="specialist", llm=mock_llm, tools=[mock_tool], system_prompt="sys prompt"
    )

    compiled = create_megamind(
        megamind_llm=mock_llm,
        executors=[specialist],
        megamind_system_prompt="mega prompt",
    )

    assert compiled is not None
Loading