diff --git a/.env.example b/.env.example
index 3641a38..734cd69 100644
--- a/.env.example
+++ b/.env.example
@@ -1,9 +1,14 @@
 # Required Environment Variables
+# Either OPENAI_API_KEY or ANTHROPIC_API_KEY must be set
 OPENAI_API_KEY="sk-*****"
+# ANTHROPIC_API_KEY="sk-ant-*****"
 
-# Optional: Model selection (defaults to gpt-4o-mini if not set)
+# Optional: Model selection (defaults based on provider)
+# For OpenAI (default: gpt-4o-mini)
 # OPENAI_MODEL_NAME="gpt-4o-mini"
 # OPENAI_MODEL_NAME="glm-4.5-air"  # z.ai
+# For Anthropic (default: claude-haiku-4-5-20251001)
+# ANTHROPIC_MODEL_NAME="claude-haiku-4-5-20251001"
 
 # Optional: Use OpenAI-compatible APIs (uncomment one)
 # OPENAI_BASE_URL="https://api.z.ai/api/coding/paas/v4"  # z.ai
diff --git a/agentic-framework/src/agentic_framework/constants.py b/agentic-framework/src/agentic_framework/constants.py
index 5a5461e..8fe202c 100644
--- a/agentic-framework/src/agentic_framework/constants.py
+++ b/agentic-framework/src/agentic_framework/constants.py
@@ -1,5 +1,6 @@
 import os
 from pathlib import Path
+from typing import Literal
 
 from dotenv import load_dotenv
 
@@ -8,4 +9,39 @@
 BASE_DIR = Path(__file__).resolve().parent.parent.parent
 LOGS_DIR = BASE_DIR / "logs"
 
-DEFAULT_MODEL = os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini")
+Provider = Literal["openai", "anthropic"]
+
+
+def detect_provider() -> Provider:
+    """Detect which LLM provider to use based on available API keys.
+
+    Returns:
+        "anthropic" if ANTHROPIC_API_KEY is set, "openai" otherwise.
+
+    Note:
+        Anthropic takes precedence when ANTHROPIC_API_KEY is set, even if
+        both keys are available; otherwise OpenAI is used as the default.
+    """
+    if os.getenv("ANTHROPIC_API_KEY"):
+        return "anthropic"
+    return "openai"
+
+
+def get_default_model() -> str:
+    """Get the default model name based on the detected provider.
+
+    Returns:
+        Default model name for the detected provider.
+
+    Examples:
+        - Anthropic: "claude-haiku-4-5-20251001"
+        - OpenAI: "gpt-4o-mini"
+    """
+    provider = detect_provider()
+    if provider == "anthropic":
+        return os.getenv("ANTHROPIC_MODEL_NAME", "claude-haiku-4-5-20251001")
+    return os.getenv("OPENAI_MODEL_NAME", "gpt-4o-mini")
+
+
+# Legacy constant for backward compatibility
+DEFAULT_MODEL = get_default_model()
diff --git a/agentic-framework/src/agentic_framework/core/langgraph_agent.py b/agentic-framework/src/agentic_framework/core/langgraph_agent.py
index d3e07d6..2e57cfd 100644
--- a/agentic-framework/src/agentic_framework/core/langgraph_agent.py
+++ b/agentic-framework/src/agentic_framework/core/langgraph_agent.py
@@ -2,28 +2,47 @@
 from typing import Any, Dict, List, Sequence, Union
 
 from langchain.agents import create_agent
+from langchain_anthropic import ChatAnthropic
 from langchain_core.messages import BaseMessage, HumanMessage
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.memory import InMemorySaver
 
-from agentic_framework.constants import DEFAULT_MODEL
+from agentic_framework.constants import detect_provider, get_default_model
 from agentic_framework.interfaces.base import Agent
 from agentic_framework.mcp import MCPProvider
 
 
+def _create_model(model_name: str, temperature: float):  # type: ignore[no-any-return]
+    """Create the appropriate LLM model instance based on detected provider.
+
+    Args:
+        model_name: Name of the model to use.
+        temperature: Temperature setting for the model.
+
+    Returns:
+        Either ChatAnthropic or ChatOpenAI instance.
+ """ + provider = detect_provider() + if provider == "anthropic": + return ChatAnthropic(model=model_name, temperature=temperature) # type: ignore[call-arg] + return ChatOpenAI(model=model_name, temperature=temperature) + + class LangGraphMCPAgent(Agent): """Reusable base class for LangGraph agents with optional MCP tools.""" def __init__( self, - model_name: str = DEFAULT_MODEL, + model_name: str | None = None, temperature: float = 0.1, mcp_provider: MCPProvider | None = None, initial_mcp_tools: List[Any] | None = None, thread_id: str = "1", **kwargs: Any, ): - self.model = ChatOpenAI(model=model_name, temperature=temperature) + if model_name is None: + model_name = get_default_model() + self.model = _create_model(model_name, temperature) self._mcp_provider = mcp_provider self._initial_mcp_tools = initial_mcp_tools self._thread_id = thread_id diff --git a/agentic-framework/src/agentic_framework/core/simple_agent.py b/agentic-framework/src/agentic_framework/core/simple_agent.py index 6717b43..694b95f 100644 --- a/agentic-framework/src/agentic_framework/core/simple_agent.py +++ b/agentic-framework/src/agentic_framework/core/simple_agent.py @@ -1,14 +1,31 @@ from typing import Any, Dict, List, Union +from langchain_anthropic import ChatAnthropic from langchain_core.messages import BaseMessage from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI -from agentic_framework.constants import DEFAULT_MODEL +from agentic_framework.constants import detect_provider, get_default_model from agentic_framework.interfaces.base import Agent from agentic_framework.registry import AgentRegistry +def _create_model(model_name: str, temperature: float): # type: ignore[no-any-return] + """Create the appropriate LLM model instance based on detected provider. + + Args: + model_name: Name of the model to use. + temperature: Temperature setting for the model. + + Returns: + Either ChatAnthropic or ChatOpenAI instance. + """ + provider = detect_provider() + if provider == "anthropic": + return ChatAnthropic(model=model_name, temperature=temperature) # type: ignore[call-arg] + return ChatOpenAI(model=model_name, temperature=temperature) + + @AgentRegistry.register("simple", mcp_servers=None) class SimpleAgent(Agent): """ @@ -16,8 +33,10 @@ class SimpleAgent(Agent): No MCP access (mcp_servers=None in registry). 
""" - def __init__(self, model_name: str = DEFAULT_MODEL, temperature: float = 0.0, **kwargs: Any) -> None: - self.model = ChatOpenAI(model=model_name, temperature=temperature) + def __init__(self, model_name: str | None = None, temperature: float = 0.0, **kwargs: Any) -> None: + if model_name is None: + model_name = get_default_model() + self.model = _create_model(model_name, temperature) self.prompt = ChatPromptTemplate.from_messages( [("system", "You are a helpful assistant."), ("user", "{input}")] ) diff --git a/agentic-framework/tests/test_agent.py b/agentic-framework/tests/test_agent.py index 558757a..2e0243f 100644 --- a/agentic-framework/tests/test_agent.py +++ b/agentic-framework/tests/test_agent.py @@ -9,16 +9,19 @@ def test_simple_agent_initialization(): - with patch("agentic_framework.core.simple_agent.ChatOpenAI") as MockChatOpenAI: - # Configure the mock - MockChatOpenAI.return_value + with patch("agentic_framework.core.simple_agent._create_model") as MockCreateModel: + # Configure the mock - return a callable that behaves like a Runnable + def fake_model(*args, **kwargs): + return SimpleNamespace() + + MockCreateModel.return_value = fake_model agent = SimpleAgent(model_name="gpt-4o-mini") assert agent is not None assert agent.get_tools() == [] - # Verify ChatOpenAI was initialized with correct params - MockChatOpenAI.assert_called_once_with(model="gpt-4o-mini", temperature=0.0) + # Verify _create_model was called with correct params + MockCreateModel.assert_called_once_with("gpt-4o-mini", 0.0) def test_simple_agent_run_with_string(monkeypatch): @@ -31,7 +34,15 @@ class FakePrompt: def __or__(self, model): return FakeChain() - monkeypatch.setattr("agentic_framework.core.simple_agent.ChatOpenAI", lambda **kwargs: object()) + def fake_model(*args, **kwargs): + return FakeModel() + + class FakeModel: + def __or__(self, other): + # Return a chain when model is combined with prompt + return FakeChain() + + monkeypatch.setattr("agentic_framework.core.simple_agent._create_model", fake_model) monkeypatch.setattr( "agentic_framework.core.simple_agent.ChatPromptTemplate", SimpleNamespace(from_messages=lambda messages: FakePrompt()), @@ -52,7 +63,14 @@ async def ainvoke(self, payload): return FakeChain() - monkeypatch.setattr("agentic_framework.core.simple_agent.ChatOpenAI", lambda **kwargs: object()) + class FakeModel: + def __or__(self, other): + return FakePrompt() + + def fake_model(*args, **kwargs): + return FakeModel() + + monkeypatch.setattr("agentic_framework.core.simple_agent._create_model", fake_model) monkeypatch.setattr( "agentic_framework.core.simple_agent.ChatPromptTemplate", SimpleNamespace(from_messages=lambda messages: FakePrompt()), diff --git a/agentic-framework/tests/test_langgraph_agent.py b/agentic-framework/tests/test_langgraph_agent.py index 139d9ca..734d6f0 100644 --- a/agentic-framework/tests/test_langgraph_agent.py +++ b/agentic-framework/tests/test_langgraph_agent.py @@ -34,11 +34,25 @@ async def ainvoke(self, payload, config): return {"messages": [SimpleNamespace(content="ok")]} +class DummyModel: + """A fake model that can be combined with other runnables.""" + + def __init__(self): + pass + + def __or__(self, other): + # Return a DummyGraph when combined with create_agent + return DummyGraph() + + def test_langgraph_agent_initializes_with_local_and_initial_mcp_tools(monkeypatch): graph = DummyGraph() captured = {} - monkeypatch.setattr("agentic_framework.core.langgraph_agent.ChatOpenAI", lambda **kwargs: object()) + def fake_model(*args, **kwargs): + 
+        return DummyModel()
+
+    monkeypatch.setattr("agentic_framework.core.langgraph_agent._create_model", fake_model)
 
     def fake_create_agent(**kwargs):
         captured.update(kwargs)
@@ -61,7 +75,10 @@ def test_langgraph_agent_uses_provider_tools_once(monkeypatch):
     provider = DummyProvider(["mcp-a", "mcp-b"])
     captured = {}
 
-    monkeypatch.setattr("agentic_framework.core.langgraph_agent.ChatOpenAI", lambda **kwargs: object())
+    def fake_model(*args, **kwargs):
+        return DummyModel()
+
+    monkeypatch.setattr("agentic_framework.core.langgraph_agent._create_model", fake_model)
 
     def fake_create_agent(**kwargs):
         captured.update(kwargs)
@@ -81,7 +98,10 @@ def fake_create_agent(**kwargs):
 def test_langgraph_agent_run_accepts_message_list_and_custom_config(monkeypatch):
     graph = DummyGraph()
 
-    monkeypatch.setattr("agentic_framework.core.langgraph_agent.ChatOpenAI", lambda **kwargs: object())
+    def fake_model(*args, **kwargs):
+        return DummyModel()
+
+    monkeypatch.setattr("agentic_framework.core.langgraph_agent._create_model", fake_model)
     monkeypatch.setattr("agentic_framework.core.langgraph_agent.create_agent", lambda **kwargs: graph)
 
     agent = DummyAgent(initial_mcp_tools=[])
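
Reviewer note (not part of the diff): the snippet below is a minimal sketch of how the new provider selection introduced above is expected to behave, assuming the agentic_framework package from this branch is importable. The key values are dummies, and the provider and model names are taken from the code in this diff.

    import os

    from agentic_framework.constants import detect_provider, get_default_model

    # With only OPENAI_API_KEY set (and no *_MODEL_NAME overrides in the
    # environment or .env), the OpenAI defaults apply.
    os.environ.pop("ANTHROPIC_API_KEY", None)
    os.environ["OPENAI_API_KEY"] = "sk-test"         # dummy value
    print(detect_provider())    # -> "openai"
    print(get_default_model())  # -> "gpt-4o-mini"

    # Once ANTHROPIC_API_KEY is present it takes precedence, even when
    # OPENAI_API_KEY is also set.
    os.environ["ANTHROPIC_API_KEY"] = "sk-ant-test"  # dummy value
    print(detect_provider())    # -> "anthropic"
    print(get_default_model())  # -> "claude-haiku-4-5-20251001"

Because DEFAULT_MODEL is now computed from get_default_model() at import time, existing OpenAI-only setups keep the previous gpt-4o-mini behaviour unless ANTHROPIC_API_KEY is introduced.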