diff --git a/src/server/mod.rs b/src/server/mod.rs index 64a826f..3ef44ae 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -98,7 +98,7 @@ where pub prompt_manager: Arc, pub logging_manager: Arc>, notification_tx: mpsc::Sender, - notification_rx: Option>, // Make this Option + notification_rx: Option>, state: Arc<(watch::Sender, watch::Receiver)>, supported_versions: Vec, client_capabilities: Arc>>, @@ -127,6 +127,7 @@ where })), tool_manager: Arc::new(ToolManager::new(ToolCapabilities { list_changed: config + .clone() .capabilities .as_ref() .is_some_and(|c| c.tools.as_ref().is_some_and(|t| t.list_changed)), @@ -146,6 +147,26 @@ where } } + // Initialize the server with notification support + pub fn initialize(&mut self) { + // Create a new tool manager with notification support + if self + .config + .capabilities + .as_ref() + .is_some_and(|c| c.tools.as_ref().is_some_and(|t| t.list_changed)) + { + let notification_tx = self.notification_tx.clone(); + let tool_capabilities = ToolCapabilities { list_changed: true }; + + // Replace the tool manager with one that has notification support + self.tool_manager = Arc::new(ToolManager::with_notification_sender( + tool_capabilities, + notification_tx, + )); + } + } + pub async fn process_request( &self, method: &str, @@ -155,11 +176,13 @@ where } pub async fn run_stdio_transport(&mut self) -> Result<(), McpError> { + self.initialize(); let transport = StdioTransport::new(Some(1024)); self.run_transport(transport).await } pub async fn run_sse_transport(&mut self) -> Result<(), McpError> { + self.initialize(); let transport = SseTransport::new_server( self.config.server.host.clone(), self.config.server.port, @@ -169,6 +192,7 @@ where } pub async fn run_websocket_transport(&mut self) -> Result<(), McpError> { + self.initialize(); let transport = WebSocketTransport::new_server( self.config.server.host.clone(), self.config.server.port, @@ -180,7 +204,7 @@ where #[cfg(unix)] pub async fn run_unix_transport(&mut self) -> Result<(), McpError> { tracing::info!("Starting Unix transport"); - + self.initialize(); let transport = crate::transport::unix::UnixTransport::new_server( PathBuf::from(&self.config.server.host), Some(1024), diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 8fd90ff..0a8290c 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -3,13 +3,14 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::{collections::HashMap, sync::Arc}; use test_tool::{PingTool, TestTool}; -use tokio::sync::RwLock; +use tokio::sync::{mpsc, RwLock}; pub mod calculator; pub mod file_system; pub mod test_tool; use crate::error::McpError; +use crate::protocol::JsonRpcNotification; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -115,15 +116,44 @@ pub struct ToolCapabilities { pub list_changed: bool, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolUpdateNotification { + pub tool_name: String, + pub update_type: ToolUpdateType, + pub details: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolUpdateType { + Added, + Updated, + Removed, +} + pub struct ToolManager { pub tools: Arc>>>, pub capabilities: ToolCapabilities, + notification_tx: Option>, } impl ToolManager { pub fn new(capabilities: ToolCapabilities) -> Self { Self { tools: Arc::new(RwLock::new(HashMap::new())), + notification_tx: None, + capabilities, + } + } + + pub fn with_notification_sender( + capabilities: ToolCapabilities, + notification_tx: mpsc::Sender, + ) -> Self { + Self { + tools: Arc::new(RwLock::new(HashMap::new())), + notification_tx: Some(notification_tx), capabilities, } } @@ -131,7 +161,88 @@ impl ToolManager { pub async fn register_tool(&self, provider: Arc) { let tool = provider.get_tool().await; let mut tools = self.tools.write().await; - tools.insert(tool.name, provider); + tools.insert(tool.name.clone(), provider); + + // Send notification if tool updates are enabled + if self.capabilities.list_changed { + self.send_tool_update_notification( + &tool.name, + ToolUpdateType::Added, + Some(format!("Tool '{}' registered", tool.name)), + ) + .await; + } + } + + pub async fn unregister_tool(&self, name: &str) -> Result<(), McpError> { + let mut tools = self.tools.write().await; + if tools.remove(name).is_some() { + // Send notification if tool updates are enabled + if self.capabilities.list_changed { + self.send_tool_update_notification( + name, + ToolUpdateType::Removed, + Some(format!("Tool '{}' unregistered", name)), + ) + .await; + } + Ok(()) + } else { + Err(McpError::InvalidRequest(format!( + "Tool '{}' not found", + name + ))) + } + } + + pub async fn update_tool(&self, provider: Arc) -> Result<(), McpError> { + let tool = provider.get_tool().await; + let mut tools = self.tools.write().await; + + if tools.contains_key(&tool.name) { + tools.insert(tool.name.clone(), provider); + + // Send notification if tool updates are enabled + if self.capabilities.list_changed { + self.send_tool_update_notification( + &tool.name, + ToolUpdateType::Updated, + Some(format!("Tool '{}' updated", tool.name)), + ) + .await; + } + Ok(()) + } else { + Err(McpError::InvalidRequest(format!( + "Tool '{}' not found", + tool.name + ))) + } + } + + async fn send_tool_update_notification( + &self, + tool_name: &str, + update_type: ToolUpdateType, + details: Option, + ) { + if let Some(tx) = &self.notification_tx { + let notification = ToolUpdateNotification { + tool_name: tool_name.to_string(), + update_type, + details, + }; + + let json_notification = JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "tools/update".to_string(), + params: Some(serde_json::to_value(notification).unwrap_or_default()), + }; + + if let Err(e) = tx.send(json_notification).await { + tracing::error!("Failed to send tool update notification: {}", e); + } + } } pub async fn list_tools(&self, _cursor: Option) -> Result {