diff --git a/src/ros2_medkit_mcp/config.py b/src/ros2_medkit_mcp/config.py index 659721f..842c6a6 100644 --- a/src/ros2_medkit_mcp/config.py +++ b/src/ros2_medkit_mcp/config.py @@ -27,7 +27,7 @@ class Settings(BaseModel): description="Base URL of the ros2_medkit SOVD API", ) bearer_token: str | None = Field( - default_factory=lambda: os.getenv("ROS2_MEDKIT_BEARER_TOKEN"), + default_factory=lambda: os.getenv("ROS2_MEDKIT_BEARER_TOKEN") or None, description="Optional Bearer token for authentication", ) timeout_seconds: float = Field( diff --git a/src/ros2_medkit_mcp/mcp_app.py b/src/ros2_medkit_mcp/mcp_app.py index 9c89c05..1d8d5b6 100644 --- a/src/ros2_medkit_mcp/mcp_app.py +++ b/src/ros2_medkit_mcp/mcp_app.py @@ -58,6 +58,7 @@ UpdateExecutionArgs, filter_entities, ) +from ros2_medkit_mcp.plugin import McpPlugin logger = logging.getLogger(__name__) @@ -630,18 +631,23 @@ async def download_rosbags_for_fault( } -def register_tools(server: Server, client: SovdClient) -> None: +def register_tools( + server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None +) -> None: """Register all MCP tools on the server. Args: server: The MCP server to register tools on. client: The SOVD client for making API calls. + plugins: Optional list of plugins providing additional tools. """ + # Tool name → plugin mapping, built during list_tools and used for dispatch + plugin_tool_map: dict[str, McpPlugin] = {} @server.list_tools() async def list_tools() -> list[Tool]: """List available tools.""" - return [ + tools = [ # ==================== Discovery ==================== Tool( name="sovd_version", @@ -1491,6 +1497,31 @@ async def list_tools() -> list[Tool]: }, ), ] + # Append plugin tools + if plugins: + for plugin in plugins: + try: + plugin_tools = plugin.list_tools() + for t in plugin_tools: + if t.name in TOOL_ALIASES: + logger.warning( + "Plugin %s: tool '%s' collides with built-in tool, skipping", + plugin.name, + t.name, + ) + continue + if t.name in plugin_tool_map: + logger.warning( + "Plugin %s: tool '%s' collides with another plugin tool, skipping", + plugin.name, + t.name, + ) + continue + tools.append(t) + plugin_tool_map[t.name] = plugin + except Exception: + logger.exception("Failed to list tools from plugin: %s", plugin.name) + return tools @server.call_tool() async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: @@ -1794,6 +1825,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: ) else: + # Check plugin tool map before reporting unknown tool + plugin = plugin_tool_map.get(normalized_name) + if plugin is not None: + return await plugin.call_tool(normalized_name, arguments) return format_error(f"Unknown tool: {name}") except SovdClientError as e: @@ -1849,15 +1884,18 @@ async def read_resource(uri: str) -> list[TextContent]: raise ValueError(f"Unknown resource URI: {uri}") -def setup_mcp_app(server: Server, settings: Settings, client: SovdClient) -> None: +def setup_mcp_app( + server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None +) -> None: """Set up the complete MCP application. Args: server: The MCP server to configure. settings: Application settings. client: The SOVD client for API calls. + plugins: Optional list of plugins providing additional tools. """ - register_tools(server, client) + register_tools(server, client, plugins=plugins) register_resources(server) logger.info( "MCP server configured for %s", diff --git a/src/ros2_medkit_mcp/plugin.py b/src/ros2_medkit_mcp/plugin.py new file mode 100644 index 0000000..9917350 --- /dev/null +++ b/src/ros2_medkit_mcp/plugin.py @@ -0,0 +1,75 @@ +"""Plugin interface for ros2_medkit_mcp. + +Third-party packages can register as plugins via entry_points: + + [project.entry-points."ros2_medkit_mcp.plugins"] + my_plugin = "my_package.plugin:MyPlugin" + +Plugins must implement the McpPlugin protocol. +""" + +from __future__ import annotations + +import logging +from importlib.metadata import entry_points +from typing import Any, Protocol + +from mcp.types import TextContent, Tool + +logger = logging.getLogger(__name__) + +PLUGIN_GROUP = "ros2_medkit_mcp.plugins" + + +class McpPlugin(Protocol): + """Interface for MCP server plugins.""" + + @property + def name(self) -> str: ... + + def list_tools(self) -> list[Tool]: ... + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> list[TextContent]: ... + + async def startup(self) -> None: ... + + async def shutdown(self) -> None: ... + + +def discover_plugins() -> list[McpPlugin]: + """Discover and instantiate plugins registered via entry_points.""" + plugins: list[McpPlugin] = [] + for ep in entry_points(group=PLUGIN_GROUP): + try: + plugin_cls = ep.load() + plugin = plugin_cls() + if not hasattr(plugin, "name") or not hasattr(plugin, "list_tools"): + logger.warning("Plugin %s does not implement McpPlugin, skipping", ep.name) + continue + logger.info("Discovered plugin: %s (from %s)", plugin.name, ep.value) + plugins.append(plugin) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + return plugins + + +async def start_plugins(plugins: list[McpPlugin]) -> list[McpPlugin]: + """Start plugins, returning only those that started successfully.""" + started: list[McpPlugin] = [] + for plugin in plugins: + try: + await plugin.startup() + started.append(plugin) + logger.info("Plugin started: %s", plugin.name) + except Exception: + logger.exception("Failed to start plugin: %s", plugin.name) + return started + + +async def shutdown_plugins(plugins: list[McpPlugin]) -> None: + """Shut down plugins, logging errors without raising.""" + for plugin in plugins: + try: + await plugin.shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) diff --git a/src/ros2_medkit_mcp/server_http.py b/src/ros2_medkit_mcp/server_http.py index bead326..2691bca 100644 --- a/src/ros2_medkit_mcp/server_http.py +++ b/src/ros2_medkit_mcp/server_http.py @@ -18,6 +18,7 @@ from ros2_medkit_mcp.client import SovdClient from ros2_medkit_mcp.config import get_settings from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app +from ros2_medkit_mcp.plugin import McpPlugin, discover_plugins, shutdown_plugins, start_plugins # Configure logging logging.basicConfig( @@ -36,7 +37,7 @@ def create_app() -> Starlette: settings = get_settings() mcp_server = create_mcp_server() client = SovdClient(settings) - setup_mcp_app(mcp_server, settings, client) + plugins = discover_plugins() # Create SSE transport - path is where clients POST messages sse_transport = SseServerTransport("/mcp/messages/") @@ -77,13 +78,19 @@ async def health_check(_request: Request) -> JSONResponse: } ) + started_plugins: list[McpPlugin] = [] + async def on_startup() -> None: """Application startup handler.""" logger.info("ros2_medkit MCP server starting (HTTP transport)") logger.info("Connecting to SOVD API at %s", settings.base_url) + started = await start_plugins(plugins) + started_plugins.extend(started) + setup_mcp_app(mcp_server, settings, client, plugins=started_plugins) async def on_shutdown() -> None: """Application shutdown handler.""" + await shutdown_plugins(started_plugins) await client.close() logger.info("Server shutdown complete") diff --git a/src/ros2_medkit_mcp/server_stdio.py b/src/ros2_medkit_mcp/server_stdio.py index 1f32e6c..ce183b5 100644 --- a/src/ros2_medkit_mcp/server_stdio.py +++ b/src/ros2_medkit_mcp/server_stdio.py @@ -13,6 +13,7 @@ from ros2_medkit_mcp.client import SovdClient from ros2_medkit_mcp.config import get_settings from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app +from ros2_medkit_mcp.plugin import discover_plugins, shutdown_plugins, start_plugins # Configure logging to stderr to avoid interfering with stdio transport logging.basicConfig( @@ -31,9 +32,11 @@ async def run_server() -> None: server = create_mcp_server() client = SovdClient(settings) + plugins = discover_plugins() + started_plugins = await start_plugins(plugins) try: - setup_mcp_app(server, settings, client) + setup_mcp_app(server, settings, client, plugins=started_plugins) async with stdio_server() as (read_stream, write_stream): await server.run( @@ -42,6 +45,7 @@ async def run_server() -> None: server.create_initialization_options(), ) finally: + await shutdown_plugins(started_plugins) await client.close() logger.info("Server shutdown complete") diff --git a/tests/test_plugin_discovery.py b/tests/test_plugin_discovery.py new file mode 100644 index 0000000..0a257b9 --- /dev/null +++ b/tests/test_plugin_discovery.py @@ -0,0 +1,273 @@ +"""Tests for MCP plugin discovery and integration.""" + +import logging +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mcp.types import TextContent, Tool + +from ros2_medkit_mcp.plugin import discover_plugins, shutdown_plugins, start_plugins + + +class FakePlugin: + @property + def name(self) -> str: + return "fake" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="fake_tool", + description="A fake tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, name: str, _arguments: dict[str, Any]) -> list[TextContent]: + if name == "fake_tool": + return [TextContent(type="text", text="fake result")] + raise ValueError(f"Unknown tool: {name}") + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + +class TestDiscoverPlugins: + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_discovers_installed_plugins(self, mock_eps: MagicMock) -> None: + mock_ep = MagicMock() + mock_ep.name = "fake" + mock_ep.value = "fake_package.plugin:FakePlugin" + mock_ep.load.return_value = FakePlugin + mock_eps.return_value = [mock_ep] + plugins = discover_plugins() + assert len(plugins) == 1 + assert plugins[0].name == "fake" + + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_no_plugins_installed(self, mock_eps: MagicMock) -> None: + mock_eps.return_value = [] + plugins = discover_plugins() + assert plugins == [] + + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_broken_plugin_skipped(self, mock_eps: MagicMock) -> None: + mock_ep = MagicMock() + mock_ep.name = "broken" + mock_ep.load.side_effect = ImportError("no such module") + mock_eps.return_value = [mock_ep] + plugins = discover_plugins() + assert plugins == [] + + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_non_conforming_plugin_skipped(self, mock_eps: MagicMock) -> None: + """Plugin without required attributes is skipped with warning.""" + + class BadPlugin: + pass + + mock_ep = MagicMock() + mock_ep.name = "bad" + mock_ep.load.return_value = BadPlugin + mock_eps.return_value = [mock_ep] + plugins = discover_plugins() + assert plugins == [] + + +class TestPluginLifecycle: + @pytest.mark.asyncio + async def test_start_plugins_returns_started(self) -> None: + plugin = FakePlugin() + started = await start_plugins([plugin]) + assert len(started) == 1 + assert started[0] is plugin + + @pytest.mark.asyncio + async def test_start_plugins_skips_failed(self) -> None: + good = FakePlugin() + bad = MagicMock() + bad.name = "bad" + bad.startup = AsyncMock(side_effect=RuntimeError("init failed")) + started = await start_plugins([good, bad]) + assert len(started) == 1 + assert started[0] is good + + @pytest.mark.asyncio + async def test_shutdown_plugins_calls_all(self) -> None: + p1 = MagicMock() + p1.name = "p1" + p1.shutdown = AsyncMock() + p2 = MagicMock() + p2.name = "p2" + p2.shutdown = AsyncMock() + await shutdown_plugins([p1, p2]) + p1.shutdown.assert_awaited_once() + p2.shutdown.assert_awaited_once() + + @pytest.mark.asyncio + async def test_shutdown_plugins_continues_on_error(self) -> None: + p1 = MagicMock() + p1.name = "p1" + p1.shutdown = AsyncMock(side_effect=RuntimeError("boom")) + p2 = MagicMock() + p2.name = "p2" + p2.shutdown = AsyncMock() + await shutdown_plugins([p1, p2]) + p2.shutdown.assert_awaited_once() + + +class TestPluginToolRegistration: + """Tests for plugin tool registration and dispatch in mcp_app.register_tools.""" + + def _make_server_mock(self) -> tuple[MagicMock, dict[str, Any]]: + """Create a mock Server that captures registered handlers.""" + server = MagicMock() + handlers: dict[str, Any] = {} + + def list_tools_decorator(): + def wrapper(fn: Any) -> Any: + handlers["list_tools"] = fn + return fn + + return wrapper + + def call_tool_decorator(): + def wrapper(fn: Any) -> Any: + handlers["call_tool"] = fn + return fn + + return wrapper + + server.list_tools = list_tools_decorator + server.call_tool = call_tool_decorator + return server, handlers + + @pytest.mark.asyncio + async def test_plugin_tool_dispatch(self) -> None: + """Plugin tools are dispatched via plugin_tool_map.""" + from ros2_medkit_mcp.mcp_app import register_tools + + server, handlers = self._make_server_mock() + client = MagicMock() + plugin = FakePlugin() + + register_tools(server, client, plugins=[plugin]) + + # list_tools should include the plugin tool + tools = await handlers["list_tools"]() + tool_names = {t.name for t in tools} + assert "fake_tool" in tool_names + + # call_tool should dispatch to plugin + result = await handlers["call_tool"]("fake_tool", {}) + assert len(result) == 1 + assert result[0].text == "fake result" + + @pytest.mark.asyncio + async def test_builtin_collision_skipped(self, caplog: pytest.LogCaptureFixture) -> None: + """Plugin tool colliding with built-in is skipped with warning.""" + from ros2_medkit_mcp.mcp_app import register_tools + + class CollidingPlugin: + @property + def name(self) -> str: + return "colliding" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="sovd_health", + description="Collides with built-in", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, _name: str, _arguments: dict[str, Any]) -> list[TextContent]: + return [TextContent(type="text", text="should not reach")] + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + server, handlers = self._make_server_mock() + client = MagicMock() + register_tools(server, client, plugins=[CollidingPlugin()]) + + with caplog.at_level(logging.WARNING): + tools = await handlers["list_tools"]() + + assert "collides with built-in tool" in caplog.text + + # sovd_health should appear exactly once (the built-in) + plugin_tool_names = [t.name for t in tools if t.name == "sovd_health"] + assert len(plugin_tool_names) == 1 + + @pytest.mark.asyncio + async def test_inter_plugin_collision_skipped(self, caplog: pytest.LogCaptureFixture) -> None: + """Second plugin declaring same tool name is skipped.""" + from ros2_medkit_mcp.mcp_app import register_tools + + class PluginA: + @property + def name(self) -> str: + return "plugin_a" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="shared_tool", + description="From A", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, _name: str, _arguments: dict[str, Any]) -> list[TextContent]: + return [TextContent(type="text", text="from A")] + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + class PluginB: + @property + def name(self) -> str: + return "plugin_b" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="shared_tool", + description="From B", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, _name: str, _arguments: dict[str, Any]) -> list[TextContent]: + return [TextContent(type="text", text="from B")] + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + server, handlers = self._make_server_mock() + client = MagicMock() + register_tools(server, client, plugins=[PluginA(), PluginB()]) + + with caplog.at_level(logging.WARNING): + await handlers["list_tools"]() + + assert "collides with another plugin tool" in caplog.text + + # Dispatch should go to plugin A (first registered) + result = await handlers["call_tool"]("shared_tool", {}) + assert result[0].text == "from A"