Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ where
pub prompt_manager: Arc<PromptManager>,
pub logging_manager: Arc<tokio::sync::Mutex<LoggingManager>>,
notification_tx: mpsc::Sender<JsonRpcNotification>,
notification_rx: Option<mpsc::Receiver<JsonRpcNotification>>, // Make this Option
notification_rx: Option<mpsc::Receiver<JsonRpcNotification>>,
state: Arc<(watch::Sender<ServerState>, watch::Receiver<ServerState>)>,
supported_versions: Vec<String>,
client_capabilities: Arc<RwLock<Option<ClientCapabilities>>>,
Expand Down Expand Up @@ -127,6 +127,7 @@ where
})),
tool_manager: Arc::new(ToolManager::new(ToolCapabilities {
list_changed: config
.clone()
.capabilities
.as_ref()
.is_some_and(|c| c.tools.as_ref().is_some_and(|t| t.list_changed)),
Expand All @@ -146,6 +147,26 @@ where
}
}

// Initialize the server with notification support
pub fn initialize(&mut self) {
// Create a new tool manager with notification support
if self
.config
.capabilities
.as_ref()
.is_some_and(|c| c.tools.as_ref().is_some_and(|t| t.list_changed))
{
let notification_tx = self.notification_tx.clone();
let tool_capabilities = ToolCapabilities { list_changed: true };

// Replace the tool manager with one that has notification support
self.tool_manager = Arc::new(ToolManager::with_notification_sender(
tool_capabilities,
notification_tx,
));
}
}

pub async fn process_request(
&self,
method: &str,
Expand All @@ -155,11 +176,13 @@ where
}

pub async fn run_stdio_transport(&mut self) -> Result<(), McpError> {
self.initialize();
let transport = StdioTransport::new(Some(1024));
self.run_transport(transport).await
}

pub async fn run_sse_transport(&mut self) -> Result<(), McpError> {
self.initialize();
let transport = SseTransport::new_server(
self.config.server.host.clone(),
self.config.server.port,
Expand All @@ -169,6 +192,7 @@ where
}

pub async fn run_websocket_transport(&mut self) -> Result<(), McpError> {
self.initialize();
let transport = WebSocketTransport::new_server(
self.config.server.host.clone(),
self.config.server.port,
Expand All @@ -180,7 +204,7 @@ where
#[cfg(unix)]
pub async fn run_unix_transport(&mut self) -> Result<(), McpError> {
tracing::info!("Starting Unix transport");

self.initialize();
let transport = crate::transport::unix::UnixTransport::new_server(
PathBuf::from(&self.config.server.host),
Some(1024),
Expand Down
115 changes: 113 additions & 2 deletions src/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};
use test_tool::{PingTool, TestTool};
use tokio::sync::RwLock;
use tokio::sync::{mpsc, RwLock};

pub mod calculator;
pub mod file_system;
pub mod test_tool;

use crate::error::McpError;
use crate::protocol::JsonRpcNotification;

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
Expand Down Expand Up @@ -115,23 +116,133 @@ pub struct ToolCapabilities {
pub list_changed: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolUpdateNotification {
pub tool_name: String,
pub update_type: ToolUpdateType,
pub details: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolUpdateType {
Added,
Updated,
Removed,
}

pub struct ToolManager {
pub tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProvider>>>>,
pub capabilities: ToolCapabilities,
notification_tx: Option<mpsc::Sender<JsonRpcNotification>>,
}

impl ToolManager {
pub fn new(capabilities: ToolCapabilities) -> Self {
Self {
tools: Arc::new(RwLock::new(HashMap::new())),
notification_tx: None,
capabilities,
}
}

pub fn with_notification_sender(
capabilities: ToolCapabilities,
notification_tx: mpsc::Sender<JsonRpcNotification>,
) -> Self {
Self {
tools: Arc::new(RwLock::new(HashMap::new())),
notification_tx: Some(notification_tx),
capabilities,
}
}

pub async fn register_tool(&self, provider: Arc<dyn ToolProvider>) {
let tool = provider.get_tool().await;
let mut tools = self.tools.write().await;
tools.insert(tool.name, provider);
tools.insert(tool.name.clone(), provider);

// Send notification if tool updates are enabled
if self.capabilities.list_changed {
self.send_tool_update_notification(
&tool.name,
ToolUpdateType::Added,
Some(format!("Tool '{}' registered", tool.name)),
)
.await;
}
}

pub async fn unregister_tool(&self, name: &str) -> Result<(), McpError> {
let mut tools = self.tools.write().await;
if tools.remove(name).is_some() {
// Send notification if tool updates are enabled
if self.capabilities.list_changed {
self.send_tool_update_notification(
name,
ToolUpdateType::Removed,
Some(format!("Tool '{}' unregistered", name)),
)
.await;
}
Ok(())
} else {
Err(McpError::InvalidRequest(format!(
"Tool '{}' not found",
name
)))
}
}

pub async fn update_tool(&self, provider: Arc<dyn ToolProvider>) -> Result<(), McpError> {
let tool = provider.get_tool().await;
let mut tools = self.tools.write().await;

if tools.contains_key(&tool.name) {
tools.insert(tool.name.clone(), provider);

// Send notification if tool updates are enabled
if self.capabilities.list_changed {
self.send_tool_update_notification(
&tool.name,
ToolUpdateType::Updated,
Some(format!("Tool '{}' updated", tool.name)),
)
.await;
}
Ok(())
} else {
Err(McpError::InvalidRequest(format!(
"Tool '{}' not found",
tool.name
)))
}
}

async fn send_tool_update_notification(
&self,
tool_name: &str,
update_type: ToolUpdateType,
details: Option<String>,
) {
if let Some(tx) = &self.notification_tx {
let notification = ToolUpdateNotification {
tool_name: tool_name.to_string(),
update_type,
details,
};

let json_notification = JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "tools/update".to_string(),
params: Some(serde_json::to_value(notification).unwrap_or_default()),
};

if let Err(e) = tx.send(json_notification).await {
tracing::error!("Failed to send tool update notification: {}", e);
}
}
}

pub async fn list_tools(&self, _cursor: Option<String>) -> Result<ListToolsResponse, McpError> {
Expand Down