diff --git a/CLAUDE.md b/CLAUDE.md index 39e0ffb..1dd3b63 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -6,4 +6,5 @@ - The project uses `uv`, `ruff` and `mypy` - Run commands should be prefixed with `uv`: `uv run ...` - Use `asyncio` features, if such is needed +- Prefer early returns - Absolutely no useless comments! Every class and method does not need to be documented (unless it is legitimetly complex or "lib-ish") diff --git a/README.md b/README.md index 3717ea8..308394c 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ make chat make dev ``` -Additional MCP servers are configured in `agent-chat-cli.config.yaml` and prompts added within the `prompts` folder. +Additional MCP servers are configured in `agent-chat-cli.config.yaml` and prompts added within the `prompts` folder. By default, MCP servers are loaded dynamically via inference; set `mcp_server_inference: false` to load all servers at startup. ## Development diff --git a/agent-chat-cli.config.yaml b/agent-chat-cli.config.yaml index 2072ccd..2f114b7 100644 --- a/agent-chat-cli.config.yaml +++ b/agent-chat-cli.config.yaml @@ -9,6 +9,9 @@ model: haiku # Enable streaming responses include_partial_messages: true +# Enable dynamic MCP server inference +mcp_server_inference: true + # Named agents with custom configurations # agents: # sample_agent: diff --git a/src/agent_chat_cli/app.py b/src/agent_chat_cli/app.py index aafefb5..d6703bd 100644 --- a/src/agent_chat_cli/app.py +++ b/src/agent_chat_cli/app.py @@ -8,10 +8,10 @@ from agent_chat_cli.components.chat_history import ChatHistory, MessagePosted from agent_chat_cli.components.thinking_indicator import ThinkingIndicator from agent_chat_cli.components.user_input import UserInput -from agent_chat_cli.utils import AgentLoop -from agent_chat_cli.utils.message_bus import MessageBus +from agent_chat_cli.system.agent_loop import AgentLoop +from agent_chat_cli.system.message_bus import MessageBus +from agent_chat_cli.system.actions import Actions from 
agent_chat_cli.utils.logger import setup_logging -from agent_chat_cli.utils.actions import Actions from dotenv import load_dotenv @@ -20,7 +20,7 @@ class AgentChatCLIApp(App): - CSS_PATH = "utils/styles.tcss" + CSS_PATH = "system/styles.tcss" BINDINGS = [ Binding("ctrl+c", "quit", "Quit", show=False, priority=True), diff --git a/src/agent_chat_cli/components/user_input.py b/src/agent_chat_cli/components/user_input.py index 49deb15..3259607 100644 --- a/src/agent_chat_cli/components/user_input.py +++ b/src/agent_chat_cli/components/user_input.py @@ -9,7 +9,7 @@ from agent_chat_cli.components.chat_history import MessagePosted from agent_chat_cli.components.thinking_indicator import ThinkingIndicator from agent_chat_cli.components.messages import Message -from agent_chat_cli.utils.actions import Actions +from agent_chat_cli.system.actions import Actions from agent_chat_cli.utils.enums import ControlCommand diff --git a/src/agent_chat_cli/docs/architecture.md b/src/agent_chat_cli/docs/architecture.md index 60bc234..132a47a 100644 --- a/src/agent_chat_cli/docs/architecture.md +++ b/src/agent_chat_cli/docs/architecture.md @@ -16,24 +16,34 @@ Textual widgets responsible for UI rendering: - **UserInput**: Handles user text input and submission - **ThinkingIndicator**: Shows when agent is processing -### Utils Layer +### System Layer -#### Agent Loop (`agent_loop.py`) +#### Agent Loop (`system/agent_loop.py`) Manages the conversation loop with Claude SDK: - Maintains async queue for user queries - Handles streaming responses - Parses SDK messages into structured AgentMessage objects - Emits AgentMessageType events (STREAM_EVENT, ASSISTANT, RESULT) - Manages session persistence via session_id +- Supports dynamic MCP server inference and loading + +#### MCP Server Inference (`system/mcp_inference.py`) +Intelligently determines which MCP servers are needed for each query: +- Uses a persistent Haiku client for fast inference (~1-3s after initial boot) +- Analyzes user queries 
to infer required servers +- Maintains a cached set of inferred servers across conversation +- Returns only newly needed servers to minimize reconnections +- Can be disabled via `mcp_server_inference: false` config option -#### Message Bus (`message_bus.py`) +#### Message Bus (`system/message_bus.py`) Routes agent messages to appropriate UI components: - Handles streaming text updates - Mounts tool use messages - Controls thinking indicator state - Manages scroll-to-bottom behavior +- Displays system messages (e.g., MCP server connection notifications) -#### Actions (`actions.py`) +#### Actions (`system/actions.py`) Centralizes all user-initiated actions and controls: - **quit()**: Exits the application - **query(user_input)**: Sends user query to agent loop queue @@ -46,15 +56,20 @@ Actions are triggered via: - Keybindings in app.py (ESC → action_interrupt, Ctrl+N → action_new) - Text commands in user_input.py ("exit", "clear") -#### Config (`config.py`) +### Utils Layer + +#### Config (`utils/config.py`) Loads and validates YAML configuration: - Filters disabled MCP servers - Loads prompts from files - Expands environment variables - Combines system prompt with MCP server prompts +- Provides `get_sdk_config()` to filter app-specific config before passing to SDK ## Data Flow +### Standard Query Flow (with MCP Inference enabled) + ``` User Input ↓ @@ -64,7 +79,16 @@ MessagePosted event → ChatHistory (immediate UI update) ↓ Actions.query(user_input) → AgentLoop.query_queue.put() ↓ -Claude SDK (streaming response) +AgentLoop: MCP Server Inference (if enabled) + ↓ +infer_mcp_servers(user_message) → Haiku query + ↓ +If new servers needed: + - Post SYSTEM message ("Connecting to [servers]...") + - Disconnect client + - Reconnect with new servers (preserving session_id) + ↓ +Claude SDK (streaming response with connected MCP tools) ↓ AgentLoop._handle_message ↓ @@ -73,9 +97,26 @@ AgentMessage (typed message) → MessageBus.handle_agent_message Match on AgentMessageType: - 
STREAM_EVENT → Update streaming message widget - ASSISTANT → Mount tool use widgets + - SYSTEM → Display system notification - RESULT → Reset thinking indicator ``` +### Query Flow (with MCP Inference disabled) + +``` +User Input + ↓ +UserInput.on_input_submitted + ↓ +MessagePosted event → ChatHistory (immediate UI update) + ↓ +Actions.query(user_input) → AgentLoop.query_queue.put() + ↓ +Claude SDK (all servers pre-connected at startup) + ↓ +[Same as above from _handle_message onwards] +``` + ### Control Commands Flow ``` User Action (ESC, Ctrl+N, "clear", "exit") @@ -138,7 +179,10 @@ class Message: Configuration is loaded from `agent-chat-cli.config.yaml`: - **system_prompt**: Base system prompt (supports file paths) - **model**: Claude model to use -- **include_partial_messages**: Enable streaming +- **include_partial_messages**: Enable streaming responses (default: true) +- **mcp_server_inference**: Enable dynamic MCP server inference (default: true) + - When `true`: App boots instantly without MCP servers, connects only when needed + - When `false`: All enabled MCP servers load at startup (traditional behavior) - **mcp_servers**: MCP server configurations (filtered by enabled flag) - **agents**: Named agent configurations - **disallowed_tools**: Tool filtering @@ -146,6 +190,27 @@ Configuration is loaded from `agent-chat-cli.config.yaml`: MCP server prompts are automatically appended to the system prompt. +### MCP Server Inference + +When `mcp_server_inference: true` (default): + +1. **Fast Boot**: App starts without connecting to any MCP servers +2. **Smart Detection**: Before each query, Haiku analyzes which servers are needed +3. **Dynamic Loading**: Only connects to newly required servers +4. **Session Preservation**: Maintains conversation history when reconnecting with new servers +5. 
**Performance**: ~1-3s inference latency after initial boot (first query ~8-12s) + +Example config: +```yaml +mcp_server_inference: true # or false to disable + +mcp_servers: + github: + description: "Search code, PRs, issues" + enabled: true + # ... rest of config +``` + ## User Commands ### Text Commands diff --git a/src/agent_chat_cli/utils/actions.py b/src/agent_chat_cli/system/actions.py similarity index 95% rename from src/agent_chat_cli/utils/actions.py rename to src/agent_chat_cli/system/actions.py index 322b8c2..dea5526 100644 --- a/src/agent_chat_cli/utils/actions.py +++ b/src/agent_chat_cli/system/actions.py @@ -1,4 +1,4 @@ -from agent_chat_cli.utils.agent_loop import AgentLoop +from agent_chat_cli.system.agent_loop import AgentLoop from agent_chat_cli.utils.enums import ControlCommand from agent_chat_cli.components.chat_history import ChatHistory from agent_chat_cli.components.thinking_indicator import ThinkingIndicator diff --git a/src/agent_chat_cli/utils/agent_loop.py b/src/agent_chat_cli/system/agent_loop.py similarity index 58% rename from src/agent_chat_cli/utils/agent_loop.py rename to src/agent_chat_cli/system/agent_loop.py index e4e947e..823e313 100644 --- a/src/agent_chat_cli/utils/agent_loop.py +++ b/src/agent_chat_cli/system/agent_loop.py @@ -13,8 +13,14 @@ ToolUseBlock, ) -from agent_chat_cli.utils.config import load_config +from agent_chat_cli.utils.config import ( + load_config, + get_available_servers, + get_sdk_config, +) from agent_chat_cli.utils.enums import AgentMessageType, ContentType, ControlCommand +from agent_chat_cli.system.mcp_inference import infer_mcp_servers +from agent_chat_cli.utils.logger import log_json @dataclass @@ -31,12 +37,10 @@ def __init__( ) -> None: self.config = load_config() self.session_id = session_id + self.available_servers = get_available_servers() + self.inferred_servers: set[str] = set() - config_dict = self.config.model_dump() - if session_id: - config_dict["resume"] = session_id - - self.client = 
ClaudeSDKClient(options=ClaudeAgentOptions(**config_dict)) + self.client: ClaudeSDKClient self.on_message = on_message self.query_queue: asyncio.Queue[str | ControlCommand] = asyncio.Queue() @@ -44,9 +48,28 @@ def __init__( self._running = False self.interrupting = False - async def start(self) -> None: + async def _initialize_client(self, mcp_servers: dict) -> None: + sdk_config = get_sdk_config(self.config) + sdk_config["mcp_servers"] = mcp_servers + + if self.session_id: + sdk_config["resume"] = self.session_id + + self.client = ClaudeSDKClient(options=ClaudeAgentOptions(**sdk_config)) + await self.client.connect() + async def start(self) -> None: + if self.config.mcp_server_inference: + await self._initialize_client(mcp_servers={}) + else: + mcp_servers = { + name: config.model_dump() + for name, config in self.available_servers.items() + } + + await self._initialize_client(mcp_servers=mcp_servers) + self._running = True while self._running: @@ -54,11 +77,53 @@ async def start(self) -> None: if isinstance(user_input, ControlCommand): if user_input == ControlCommand.NEW_CONVERSATION: + self.inferred_servers.clear() + await self.client.disconnect() - await self.client.connect() + + if self.config.mcp_server_inference: + await self._initialize_client(mcp_servers={}) + else: + mcp_servers = { + name: config.model_dump() + for name, config in self.available_servers.items() + } + + await self._initialize_client(mcp_servers=mcp_servers) continue + if self.config.mcp_server_inference: + inference_result = await infer_mcp_servers( + user_message=user_input, + available_servers=self.available_servers, + inferred_servers=self.inferred_servers, + session_id=self.session_id, + ) + + if inference_result["new_servers"]: + server_list = ", ".join(inference_result["new_servers"]) + + await self.on_message( + AgentMessage( + type=AgentMessageType.SYSTEM, + data=f"Connecting to {server_list}...", + ) + ) + + await asyncio.sleep(0.1) + + await self.client.disconnect() + + 
# Shared inference client, created lazily by _get_inference_client().
_inference_client: ClaudeSDKClient | None = None


def _build_system_prompt(available_servers: dict[str, MCPServerConfig]) -> str:
    """Render the inference system prompt listing every available server."""
    server_descriptions = "\n".join(
        f"- {name}: {config.description}"
        for name, config in available_servers.items()
    )

    # Dedent the template BEFORE substituting the descriptions: from the second
    # server on, the joined lines start at column 0, which would leave dedent()
    # with no common leading whitespace to strip.
    template = dedent(
        """
        You are an MCP server inference engine. Based on the user's message, determine which MCP servers are needed to fulfill the request.

        Available MCP servers:
        {server_descriptions}

        Return ONLY the names of servers that are likely needed for this request. If no specific servers are needed, return an empty array.

        Examples:
        - "Show me my GitHub issues" → ["github"]
        - "Open a browser tab" → ["chrome"]
        - "What's the weather?" → []
        - "Search my Notion workspace and open related GitHub PRs" → ["notion", "github"]
        """
    ).strip()

    # str.replace (not str.format) so braces inside a description are harmless.
    return template.replace("{server_descriptions}", server_descriptions)


async def _get_inference_client(
    available_servers: dict[str, MCPServerConfig],
) -> ClaudeSDKClient:
    """Return the shared inference client, creating and connecting it on first use.

    The system prompt is built from ``available_servers`` on the first call
    only; later calls return the cached client whatever argument they pass.

    Args:
        available_servers: Server name -> config; each config's ``description``
            is listed in the system prompt.

    Returns:
        The connected module-level client.
    """
    global _inference_client

    if _inference_client is not None:
        return _inference_client

    inference_options = ClaudeAgentOptions(
        model="haiku",
        output_format={
            "type": "json_schema",
            "schema": {
                "type": "object",
                "properties": {
                    "servers": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of MCP server names to connect to",
                    }
                },
                "required": ["servers"],
            },
        },
        system_prompt=_build_system_prompt(available_servers),
        mcp_servers={},
    )

    client = ClaudeSDKClient(options=inference_options)
    await client.connect()

    # Publish only after connect() succeeded, so a failed connection is retried
    # on the next call instead of being cached.
    _inference_client = client

    return client


async def infer_mcp_servers(
    user_message: str,
    available_servers: dict[str, MCPServerConfig],
    inferred_servers: set[str],
    session_id: str | None = None,
) -> dict[str, Any]:
    """Ask the inference client which MCP servers ``user_message`` needs.

    Args:
        user_message: The user's query, sent to the inference client as-is.
        available_servers: Server name -> config to choose from.
        inferred_servers: Names inferred earlier; updated in place with the
            names selected for this message.
        session_id: Accepted for caller compatibility; not used here.

    Returns:
        A dict with two keys:
        ``selected_servers`` maps name -> config for the servers selected for
        this message; ``new_servers`` lists the selected names that were not
        already in ``inferred_servers``. Both are empty when
        ``available_servers`` is empty (no inference request is made).
    """
    if not available_servers:
        return {"selected_servers": {}, "new_servers": []}

    client = await _get_inference_client(available_servers)

    raw_names: Any = []

    await client.query(user_message)

    async for message in client.receive_response():
        if not isinstance(message, ResultMessage):
            continue

        structured_output = getattr(message, "structured_output", None)
        if isinstance(structured_output, dict):
            raw_names = structured_output.get("servers") or []

    if not isinstance(raw_names, list):
        raw_names = []

    # Keep only names that really exist, once each and in the order returned:
    # an unknown name must not be announced as "new" or remembered as inferred.
    selected_server_names = list(
        dict.fromkeys(
            name
            for name in raw_names
            if isinstance(name, str) and name in available_servers
        )
    )

    new_servers = [
        name for name in selected_server_names if name not in inferred_servers
    ]

    inferred_servers.update(selected_server_names)

    # NOTE(review): this covers the current message only. A caller that
    # reconnects with just these configs drops servers inferred for earlier
    # messages - confirm whether it should union with inferred_servers.
    selected_servers = {
        name: available_servers[name] for name in selected_server_names
    }

    return {
        "selected_servers": selected_servers,
        "new_servers": new_servers,
    }
agent_chat_cli.components.user_input import UserInput from agent_chat_cli.components.messages import ( AgentMessage as AgentMessageWidget, + Message, ToolMessage, ) -from agent_chat_cli.utils.agent_loop import AgentMessage +from agent_chat_cli.system.agent_loop import AgentMessage from agent_chat_cli.utils.enums import AgentMessageType, ContentType if TYPE_CHECKING: @@ -32,6 +33,9 @@ async def handle_agent_message(self, message: AgentMessage) -> None: case AgentMessageType.ASSISTANT: await self._handle_assistant(message) + case AgentMessageType.SYSTEM: + await self._handle_system(message) + case AgentMessageType.RESULT: await self._handle_result() @@ -86,6 +90,14 @@ async def _handle_assistant(self, message: AgentMessage) -> None: await self._scroll_to_bottom() + async def _handle_system(self, message: AgentMessage) -> None: + system_content = ( + message.data if isinstance(message.data, str) else str(message.data) + ) + + self.app.post_message(MessagePosted(Message.system(system_content))) + await self._scroll_to_bottom() + async def _handle_result(self) -> None: thinking_indicator = self.app.query_one(ThinkingIndicator) thinking_indicator.is_thinking = False diff --git a/src/agent_chat_cli/utils/styles.tcss b/src/agent_chat_cli/system/styles.tcss similarity index 100% rename from src/agent_chat_cli/utils/styles.tcss rename to src/agent_chat_cli/system/styles.tcss diff --git a/src/agent_chat_cli/utils/__init__.py b/src/agent_chat_cli/utils/__init__.py index 3f9d258..7a26ab8 100644 --- a/src/agent_chat_cli/utils/__init__.py +++ b/src/agent_chat_cli/utils/__init__.py @@ -1,13 +1,10 @@ """Utility modules for Agent Chat CLI.""" -from agent_chat_cli.utils.agent_loop import AgentLoop, AgentMessage from agent_chat_cli.utils.enums import AgentMessageType, ContentType from agent_chat_cli.utils.tool_info import ToolInfo, get_tool_info from agent_chat_cli.utils.format_tool_input import format_tool_input __all__ = [ - "AgentLoop", - "AgentMessage", "AgentMessageType", 
load_dotenv()

# TODO: This can be deleted, but keeping here to check if speed is related to Anthropic's
# servers or something else


async def test_inference():
    """
    Time MCP server inference for a handful of sample queries.

    Prints the selected servers, the newly inferred servers and the elapsed
    time for each query, then disconnects the shared inference client.

    To run: uv run python tests/test_mcp_inference.py
    """
    # Look the client up on the module at cleanup time. A
    # `from ... import _inference_client` binding is captured at import, when
    # the client is still None, so it would never see the client created below.
    from agent_chat_cli.system import mcp_inference

    print("=== MCP Server Inference Test ===\n")

    available_servers = get_available_servers()
    print(f"Available servers: {list(available_servers.keys())}\n")

    inferred_servers: set[str] = set()

    test_queries = [
        "Show me my GitHub issues",
        "Open a browser tab",
        "What's the weather?",
        "Search my GitHub for code related to authentication",
    ]

    try:
        for user_message in test_queries:
            print(f"Query: {user_message}")

            # perf_counter is monotonic, so the interval cannot be skewed by
            # system clock adjustments.
            start_time = time.perf_counter()

            result = await infer_mcp_servers(
                user_message=user_message,
                available_servers=available_servers,
                inferred_servers=inferred_servers,
            )

            elapsed = time.perf_counter() - start_time

            print(f"Selected servers: {list(result['selected_servers'].keys())}")
            print(f"New servers: {result['new_servers']}")
            print(f"Time: {elapsed:.2f}s")
            print(f"Inferred servers so far: {inferred_servers}")
            print("-" * 50 + "\n")
    finally:
        # Disconnect even when a query raised.
        client = mcp_inference._inference_client
        if client is not None:
            print("Disconnecting inference client...")
            await client.disconnect()
            print("Done!")


if __name__ == "__main__":
    asyncio.run(test_inference())