From 8fcfe856270fc33b9deaac0502b0657deec83365 Mon Sep 17 00:00:00 2001 From: Naveen Narayanan Date: Sat, 22 Mar 2025 15:10:36 -0700 Subject: [PATCH 1/4] [MCP] Allow tool_id to be passed in --- Cargo.lock | 21 ++++ Cargo.toml | 1 + bin/client.rs | 6 +- src/client/mod.rs | 7 +- src/server/mod.rs | 2 +- src/tools/calculator.rs | 9 +- src/tools/file_system/directory.rs | 7 +- src/tools/file_system/mod.rs | 162 ++++++++++++++++++----------- src/tools/file_system/read.rs | 7 +- src/tools/file_system/search.rs | 7 +- src/tools/file_system/write.rs | 7 +- src/tools/mod.rs | 28 ++++- src/tools/test_tool.rs | 7 +- tests/tools.rs | 11 +- 14 files changed, 194 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f22dfe5..f35f206 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1147,6 +1147,7 @@ dependencies = [ "tracing", "tracing-core", "tracing-subscriber", + "typed-builder", "url", "uuid", "warp", @@ -2372,6 +2373,26 @@ dependencies = [ "utf-8", ] +[[package]] +name = "typed-builder" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce63bcaf7e9806c206f7d7b9c1f38e0dce8bb165a80af0898161058b19248534" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60d8d828da2a3d759d3519cdf29a5bac49c77d039ad36d0782edadbf9cd5415b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/Cargo.toml b/Cargo.toml index 533840f..73c2f06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ reqwest = { version = "0.12", features = ["json"] } reqwest-eventsource = "0.6" mime = "0.3" tokio-tungstenite = { version = "0.26", features = ["native-tls"] } +typed-builder = "0.21.0" tracing-core = "0.1" async-recursion = "1" http = "1.1" diff --git a/bin/client.rs b/bin/client.rs index cf01b07..0802ac7 100644 --- a/bin/client.rs +++ b/bin/client.rs @@ -69,6 +69,8 @@ enum Commands { name: String, #[arg(short, long)] args: String, + #[arg(short, long)] + tool_id: Option, }, /// Set log level SetLogLevel { @@ -236,10 +238,10 @@ async fn main() -> Result<(), McpError> { println!("{}", json!(res)); } - Commands::CallTool { name, args } => { + Commands::CallTool { name, args , tool_id} => { let arguments = serde_json::from_str(&args).map_err(|e| McpError::InvalidRequest(e.to_string()))?; - let res = client.call_tool(name, arguments).await?; + let res = client.call_tool(name, arguments, tool_id).await?; println!("{}", json!(res)); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 67bb13a..6c44069 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -245,6 +245,7 @@ impl Client { &self, name: String, arguments: serde_json::Value, + tool_id: Option, ) -> Result { self.assert_initialized().await?; self.assert_capability("tools").await?; @@ -252,7 +253,11 @@ impl Client { self.protocol .request( "tools/call", - Some(CallToolRequest { name, arguments }), + Some(CallToolRequest { + name, + arguments, + tool_id, + }), None, ) .await diff --git a/src/server/mod.rs b/src/server/mod.rs index 3ef44ae..d8186c4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -347,7 +347,7 @@ where serde_json::from_value(req.params.unwrap_or_default()) .map_err(|_| McpError::InvalidParams)?; let result = tool_manager - .call_tool(¶ms.name, params.arguments) + .call_tool(¶ms.name, params.arguments, params.tool_id) .await?; Ok(serde_json::to_value(result).unwrap()) }) diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index 14b849f..69274ba 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -7,7 +7,7 @@ use crate::{ protocol::{RequestHandler, ServerCapabilities}, }; -use super::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}; +use super::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}; // Domain types #[derive(Debug, serde::Deserialize)] @@ -205,8 +205,8 @@ impl ToolProvider for CalculatorTool { } } - async fn execute(&self, arguments: serde_json::Value) -> Result { - let params: CalculatorParams = serde_json::from_value(arguments).map_err(|e| { + async fn execute(&self, arguments: CallToolArgs) -> Result { + let params: CalculatorParams = serde_json::from_value(arguments.arguments).map_err(|e| { tracing::error!("Error parsing calculator arguments: {:?}", e); McpError::InvalidParams })?; @@ -301,6 +301,7 @@ mod tests { "a": 2.0, "b": 3.0 }), + Some("calculator-1234".to_string()), ) .await .unwrap(); @@ -319,6 +320,7 @@ mod tests { "operation": "ln", "a": 2.718281828459045 }), + Some("calculator-5678".to_string()), ) .await .unwrap(); @@ -350,6 +352,7 @@ mod tests { "a": -1.0, "b": 10.0 }), + Some("calculator-9999".to_string()), ) .await .unwrap(); diff --git a/src/tools/file_system/directory.rs b/src/tools/file_system/directory.rs index 3fa8bba..e9535a9 100644 --- a/src/tools/file_system/directory.rs +++ b/src/tools/file_system/directory.rs @@ -1,11 +1,11 @@ use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::json; use std::collections::HashMap; use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct DirectoryTool; @@ -57,7 +57,8 @@ impl ToolProvider for DirectoryTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { + let arguments = &arguments.arguments; match arguments["operation"].as_str() { Some("create_directory") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/mod.rs b/src/tools/file_system/mod.rs index e1c2bae..c9185d1 100644 --- a/src/tools/file_system/mod.rs +++ b/src/tools/file_system/mod.rs @@ -8,10 +8,11 @@ use crate::{ tools::{Tool, ToolContent, ToolProvider, ToolResult}, }; use async_trait::async_trait; -use serde_json::Value; use std::path::PathBuf; use std::sync::Arc; +use super::CallToolArgs; + #[derive(Clone)] pub struct FileSystemTools { read_tool: Arc, @@ -92,9 +93,10 @@ impl ToolProvider for FileSystemTools { tools.remove(0) } - async fn execute(&self, arguments: Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { + let args = &arguments.arguments; // Add operation to list allowed directories - if arguments["operation"].as_str() == Some("list_allowed_directories") { + if args["operation"].as_str() == Some("list_allowed_directories") { let dirs = self .allowed_directories .iter() @@ -110,9 +112,7 @@ impl ToolProvider for FileSystemTools { } // Route to appropriate sub-tool based on operation type - let operation = arguments["operation"] - .as_str() - .ok_or(McpError::InvalidParams)?; + let operation = args["operation"].as_str().ok_or(McpError::InvalidParams)?; match operation { "read_file" | "read_multiple_files" => self.read_tool.execute(arguments).await, @@ -147,21 +147,30 @@ mod tests { // Test write operation let write_result = fs_tools - .execute(json!({ - "operation": "write_file", - "path": test_file.to_str().unwrap(), - "content": test_content, - })) + .execute( + CallToolArgs::builder() + .arguments(json!({ + "operation": "write_file", + "path": test_file.to_str().unwrap(), + "content": test_content, + })) + .tool_id(Some("write-to-file-1234".to_string())) + .build(), + ) .await .unwrap(); assert!(!write_result.is_error); // Test read operation let read_result = fs_tools - .execute(json!({ - "operation": "read_file", - "path": test_file.to_str().unwrap(), - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "read_file", + "path": test_file.to_str().unwrap(), + })) + .tool_id(Some("read-from-file-1234".to_string())) + .build(), + ) .await .unwrap(); @@ -178,20 +187,28 @@ mod tests { // Test directory creation let create_result = fs_tools - .execute(json!({ - "operation": "create_directory", - "path": test_dir.to_str().unwrap(), - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "create_directory", + "path": test_dir.to_str().unwrap(), + })) + .tool_id(Some("create-directory-1234".to_string())) + .build(), + ) .await .unwrap(); assert!(!create_result.is_error); // Test directory listing let list_result = fs_tools - .execute(json!({ - "operation": "list_directory", - "path": temp_dir.path().to_str().unwrap(), - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "list_directory", + "path": temp_dir.path().to_str().unwrap(), + })) + .tool_id(Some("list-directory-1234".to_string())) + .build(), + ) .await .unwrap(); @@ -210,22 +227,30 @@ mod tests { for file in &test_files { let path = temp_dir.path().join(file); fs_tools - .execute(json!({ - "operation": "write_file", - "path": path.to_str().unwrap(), - "content": "test content", - })) - .await - .unwrap(); + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "write_file", + "path": path.to_str().unwrap(), + "content": "test content", + })) + .tool_id(Some("write-to-file-1234".to_string())) + .build(), + ) + .await + .unwrap(); } // Test search let search_result = fs_tools - .execute(json!({ - "operation": "search_files", - "path": temp_dir.path().to_str().unwrap(), - "pattern": "test", - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "search_files", + "path": temp_dir.path().to_str().unwrap(), + "pattern": "test", + })) + .tool_id(Some("search-files-1234".to_string())) + .build(), + ) .await .unwrap(); @@ -247,21 +272,29 @@ mod tests { // Create source file fs_tools - .execute(json!({ - "operation": "write_file", - "path": source.to_str().unwrap(), - "content": "test content", - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "write_file", + "path": source.to_str().unwrap(), + "content": "test content", + })) + .tool_id(Some("write-to-file-1234".to_string())) + .build(), + ) .await .unwrap(); // Test move operation let move_result = fs_tools - .execute(json!({ - "operation": "move_file", - "source": source.to_str().unwrap(), - "destination": dest.to_str().unwrap(), - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "move_file", + "source": source.to_str().unwrap(), + "destination": dest.to_str().unwrap(), + })) + .tool_id(Some("move-file-1234".to_string())) + .build(), + ) .await .unwrap(); assert!(!move_result.is_error); @@ -278,11 +311,15 @@ mod tests { // Test invalid path let result = fs_tools - .execute(json!({ - "operation": "write_file", - "path": invalid_path, - "content": "test content", - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "write_file", + "path": invalid_path, + "content": "test content", + })) + .tool_id(Some("write-to-file-1234".to_string())) + .build(), + ) .await; assert!(result.is_err()); @@ -297,20 +334,27 @@ mod tests { for (i, file) in files.iter().enumerate() { let path = temp_dir.path().join(file); fs_tools - .execute(json!({ - "operation": "write_file", - "path": path.to_str().unwrap(), - "content": format!("content {}", i), - })) + .execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "write_file", + "path": path.to_str().unwrap(), + "content": format!("content {}", i), + })).build()) .await .unwrap(); } // Test reading multiple files - let read_result = fs_tools.execute(json!({ - "operation": "read_multiple_files", - "paths": files.iter().map(|f| temp_dir.path().join(f).to_str().unwrap().to_string()).collect::>(), - })).await.unwrap(); + let read_result = fs_tools.execute(CallToolArgs::builder() + .arguments(json!({ + "operation": "read_multiple_files", + "paths": files.iter().map(|f| temp_dir.path().join(f).to_str().unwrap().to_string()).collect::>(), + })) + .tool_id(Some("read-multiple-files-1234".to_string())) + .build(), + ) + .await + .unwrap(); assert_eq!(read_result.content.len(), 2); match &read_result.content[0] { diff --git a/src/tools/file_system/read.rs b/src/tools/file_system/read.rs index 0af362a..f5c2dcd 100644 --- a/src/tools/file_system/read.rs +++ b/src/tools/file_system/read.rs @@ -1,12 +1,12 @@ use async_trait::async_trait; use futures::future::try_join_all; -use serde_json::{json, Value}; +use serde_json::json; use std::collections::HashMap; use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct ReadFileTool; @@ -83,7 +83,8 @@ impl ToolProvider for ReadFileTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { + let arguments = arguments.arguments; match arguments["operation"].as_str() { Some("read_file") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/search.rs b/src/tools/file_system/search.rs index e511e68..12a66e8 100644 --- a/src/tools/file_system/search.rs +++ b/src/tools/file_system/search.rs @@ -1,12 +1,12 @@ use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::json; use std::collections::HashMap; use std::path::PathBuf; use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct SearchTool; @@ -111,7 +111,8 @@ impl ToolProvider for SearchTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { + let arguments = arguments.arguments; match arguments["operation"].as_str() { Some("search_files") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/write.rs b/src/tools/file_system/write.rs index 726e73e..331ed9b 100644 --- a/src/tools/file_system/write.rs +++ b/src/tools/file_system/write.rs @@ -1,11 +1,11 @@ use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::json; use std::collections::HashMap; use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct WriteFileTool; @@ -58,7 +58,8 @@ impl ToolProvider for WriteFileTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { + let arguments = arguments.arguments; let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; let content = arguments["content"] .as_str() diff --git a/src/tools/mod.rs b/src/tools/mod.rs index ae69b61..cf321a5 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -4,6 +4,7 @@ use serde_json::Value; use std::{collections::HashMap, sync::Arc}; use test_tool::{PingTool, TestTool}; use tokio::sync::{mpsc, RwLock}; +use typed_builder::TypedBuilder; pub mod calculator; pub mod file_system; @@ -102,6 +103,15 @@ pub struct ListToolsResponse { pub struct CallToolRequest { pub name: String, pub arguments: Value, + pub tool_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, TypedBuilder)] +pub struct CallToolArgs { + pub arguments: Value, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default, setter(into))] + pub tool_id: Option, } // Tool Provider trait @@ -111,7 +121,7 @@ pub trait ToolProvider: Send + Sync { async fn get_tool(&self) -> Tool; /// Execute tool - async fn execute(&self, arguments: Value) -> Result; + async fn execute(&self, arguments: CallToolArgs) -> Result; } // Tool Manager @@ -264,12 +274,24 @@ impl ToolManager { }) } - pub async fn call_tool(&self, name: &str, arguments: Value) -> Result { + pub async fn call_tool( + &self, + name: &str, + arguments: Value, + tool_id: Option, + ) -> Result { let tools = self.tools.read().await; let provider = tools .get(name) .ok_or_else(|| McpError::InvalidRequest(format!("Unknown tool: {}", name)))?; - provider.execute(arguments).await + provider + .execute( + CallToolArgs::builder() + .arguments(arguments) + .tool_id(tool_id) + .build(), + ) + .await } } diff --git a/src/tools/test_tool.rs b/src/tools/test_tool.rs index 2fc3cad..2c0e73a 100644 --- a/src/tools/test_tool.rs +++ b/src/tools/test_tool.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use crate::error::McpError; -use super::{Tool, ToolContent, ToolProvider, ToolResult}; +use super::{CallToolArgs, Tool, ToolContent, ToolProvider, ToolResult}; pub struct TestTool; @@ -40,7 +40,7 @@ impl ToolProvider for TestTool { } } - async fn execute(&self, _arguments: serde_json::Value) -> Result { + async fn execute(&self, _arguments: CallToolArgs) -> Result { Ok(ToolResult { content: vec![], is_error: false, @@ -85,8 +85,9 @@ impl ToolProvider for PingTool { } } - async fn execute(&self, arguments: serde_json::Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { let server = arguments + .arguments .get("server") .and_then(|s| s.as_str()) .unwrap_or("localhost"); diff --git a/tests/tools.rs b/tests/tools.rs index 0dd4e0f..9520038 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -6,7 +6,7 @@ use mcp_rs::{ error::McpError, protocol::BasicRequestHandler, server::{config::ServerConfig, McpServer}, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; // Mock tool provider for testing @@ -51,9 +51,9 @@ impl ToolProvider for MockCalculatorTool { } } - async fn execute(&self, arguments: serde_json::Value) -> Result { + async fn execute(&self, arguments: CallToolArgs) -> Result { let params: CalculatorParams = - serde_json::from_value(arguments).map_err(|_| McpError::InvalidParams)?; + serde_json::from_value(arguments.arguments).map_err(|_| McpError::InvalidParams)?; let result = match params.operation.as_str() { "add" => params.a + params.b, @@ -129,6 +129,7 @@ async fn test_tool_execution() { "a": 5, "b": 3 }), + Some("calculator-id-1234".to_string()), ) .await .unwrap(); @@ -149,6 +150,7 @@ async fn test_tool_execution() { "a": 1, "b": 0 }), + Some("calculator-id-2345".to_string()), ) .await .unwrap(); @@ -170,7 +172,7 @@ async fn test_invalid_tool() { // Test calling non-existent tool let result = server .tool_manager - .call_tool("nonexistent", json!({})) + .call_tool("nonexistent", json!({}), None) .await; assert!(result.is_err()); @@ -203,6 +205,7 @@ async fn test_invalid_arguments() { "a": 1, "b": 2 }), + Some("calculator-id-3456".to_string()), ) .await; From 97e39d066fedb5568d7fb5b2f070be1ca1b8976e Mon Sep 17 00:00:00 2001 From: Naveen Narayanan Date: Sat, 22 Mar 2025 22:23:34 -0700 Subject: [PATCH 2/4] Wrapping in metadata model --- bin/client.rs | 27 ++++- src/client/mod.rs | 9 +- src/server/mod.rs | 2 +- src/tools/calculator.rs | 29 ++++- src/tools/file_system/directory.rs | 9 +- src/tools/file_system/mod.rs | 186 ++++++++++++++++++----------- src/tools/file_system/read.rs | 9 +- src/tools/file_system/search.rs | 9 +- src/tools/file_system/write.rs | 9 +- src/tools/mod.rs | 24 ++-- src/tools/test_tool.rs | 14 ++- tests/tools.rs | 31 ++++- 12 files changed, 244 insertions(+), 114 deletions(-) diff --git a/bin/client.rs b/bin/client.rs index 0802ac7..6a01db4 100644 --- a/bin/client.rs +++ b/bin/client.rs @@ -2,9 +2,8 @@ use clap::{Parser, Subcommand}; use mcp_rs::{ client::{Client, ClientInfo}, error::McpError, - transport::sse::SseTransport, - transport::stdio::StdioTransport, - transport::ws::WebSocketTransport, + tools::CallToolArgs, + transport::{sse::SseTransport, stdio::StdioTransport, ws::WebSocketTransport}, }; use serde_json::json; @@ -71,6 +70,8 @@ enum Commands { args: String, #[arg(short, long)] tool_id: Option, + #[arg(short, long)] + session_id: Option, }, /// Set log level SetLogLevel { @@ -238,10 +239,26 @@ async fn main() -> Result<(), McpError> { println!("{}", json!(res)); } - Commands::CallTool { name, args , tool_id} => { + Commands::CallTool { + name, + args, + tool_id, + session_id, + } => { let arguments = serde_json::from_str(&args).map_err(|e| McpError::InvalidRequest(e.to_string()))?; - let res = client.call_tool(name, arguments, tool_id).await?; + let res = client + .call_tool( + name, + arguments, + Some( + CallToolArgs::builder() + .session_id(session_id) + .tool_id(tool_id) + .build(), + ), + ) + .await?; println!("{}", json!(res)); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 6c44069..d3391c7 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -9,7 +9,10 @@ use crate::{ ListResourcesRequest, ListResourcesResponse, ReadResourceRequest, ReadResourceResponse, ResourceCapabilities, }, - tools::{CallToolRequest, ListToolsRequest, ListToolsResponse, ToolCapabilities, ToolResult}, + tools::{ + CallToolArgs, CallToolRequest, ListToolsRequest, ListToolsResponse, ToolCapabilities, + ToolResult, + }, transport::{Transport, TransportCommand}, }; use serde::{Deserialize, Serialize}; @@ -245,7 +248,7 @@ impl Client { &self, name: String, arguments: serde_json::Value, - tool_id: Option, + metadata: Option, ) -> Result { self.assert_initialized().await?; self.assert_capability("tools").await?; @@ -256,7 +259,7 @@ impl Client { Some(CallToolRequest { name, arguments, - tool_id, + metadata, }), None, ) diff --git a/src/server/mod.rs b/src/server/mod.rs index d8186c4..abbc52e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -347,7 +347,7 @@ where serde_json::from_value(req.params.unwrap_or_default()) .map_err(|_| McpError::InvalidParams)?; let result = tool_manager - .call_tool(¶ms.name, params.arguments, params.tool_id) + .call_tool(¶ms.name, params.arguments, params.metadata) .await?; Ok(serde_json::to_value(result).unwrap()) }) diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index 69274ba..34ab525 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -205,8 +205,12 @@ impl ToolProvider for CalculatorTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { - let params: CalculatorParams = serde_json::from_value(arguments.arguments).map_err(|e| { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { + let params: CalculatorParams = serde_json::from_value(arguments).map_err(|e| { tracing::error!("Error parsing calculator arguments: {:?}", e); McpError::InvalidParams })?; @@ -301,7 +305,12 @@ mod tests { "a": 2.0, "b": 3.0 }), - Some("calculator-1234".to_string()), + Some( + CallToolArgs::builder() + .session_id(Some("session-1234".to_string())) + .tool_id(Some("calculator-1234".to_string())) + .build(), + ), ) .await .unwrap(); @@ -320,7 +329,12 @@ mod tests { "operation": "ln", "a": 2.718281828459045 }), - Some("calculator-5678".to_string()), + Some( + CallToolArgs::builder() + .session_id(Some("session-5678".to_string())) + .tool_id(Some("calculator-5678".to_string())) + .build(), + ), ) .await .unwrap(); @@ -352,7 +366,12 @@ mod tests { "a": -1.0, "b": 10.0 }), - Some("calculator-9999".to_string()), + Some( + CallToolArgs::builder() + .session_id(Some("session-9999".to_string())) + .tool_id(Some("calculator-9999".to_string())) + .build(), + ), ) .await .unwrap(); diff --git a/src/tools/file_system/directory.rs b/src/tools/file_system/directory.rs index e9535a9..2daabcd 100644 --- a/src/tools/file_system/directory.rs +++ b/src/tools/file_system/directory.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; use tokio::fs; @@ -57,8 +57,11 @@ impl ToolProvider for DirectoryTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { - let arguments = &arguments.arguments; + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { match arguments["operation"].as_str() { Some("create_directory") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/mod.rs b/src/tools/file_system/mod.rs index c9185d1..6654980 100644 --- a/src/tools/file_system/mod.rs +++ b/src/tools/file_system/mod.rs @@ -8,6 +8,7 @@ use crate::{ tools::{Tool, ToolContent, ToolProvider, ToolResult}, }; use async_trait::async_trait; +use serde_json::Value; use std::path::PathBuf; use std::sync::Arc; @@ -93,10 +94,13 @@ impl ToolProvider for FileSystemTools { tools.remove(0) } - async fn execute(&self, arguments: CallToolArgs) -> Result { - let args = &arguments.arguments; + async fn execute( + &self, + arguments: Value, + metadata: Option, + ) -> Result { // Add operation to list allowed directories - if args["operation"].as_str() == Some("list_allowed_directories") { + if arguments["operation"].as_str() == Some("list_allowed_directories") { let dirs = self .allowed_directories .iter() @@ -112,15 +116,19 @@ impl ToolProvider for FileSystemTools { } // Route to appropriate sub-tool based on operation type - let operation = args["operation"].as_str().ok_or(McpError::InvalidParams)?; + let operation = arguments["operation"] + .as_str() + .ok_or(McpError::InvalidParams)?; match operation { - "read_file" | "read_multiple_files" => self.read_tool.execute(arguments).await, - "write_file" => self.write_tool.execute(arguments).await, + "read_file" | "read_multiple_files" => { + self.read_tool.execute(arguments, metadata).await + } + "write_file" => self.write_tool.execute(arguments, metadata).await, "create_directory" | "list_directory" | "move_file" => { - self.directory_tool.execute(arguments).await + self.directory_tool.execute(arguments, metadata).await } - "search_files" | "get_file_info" => self.search_tool.execute(arguments).await, + "search_files" | "get_file_info" => self.search_tool.execute(arguments, metadata).await, _ => Err(McpError::InvalidParams), } } @@ -148,14 +156,17 @@ mod tests { // Test write operation let write_result = fs_tools .execute( - CallToolArgs::builder() - .arguments(json!({ - "operation": "write_file", - "path": test_file.to_str().unwrap(), - "content": test_content, - })) - .tool_id(Some("write-to-file-1234".to_string())) - .build(), + json!({ + "operation": "write_file", + "path": test_file.to_str().unwrap(), + "content": test_content, + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -163,13 +174,17 @@ mod tests { // Test read operation let read_result = fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "read_file", "path": test_file.to_str().unwrap(), - })) - .tool_id(Some("read-from-file-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("read-from-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -187,13 +202,17 @@ mod tests { // Test directory creation let create_result = fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "create_directory", "path": test_dir.to_str().unwrap(), - })) - .tool_id(Some("create-directory-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("create-directory-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -201,13 +220,17 @@ mod tests { // Test directory listing let list_result = fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "list_directory", "path": temp_dir.path().to_str().unwrap(), - })) - .tool_id(Some("list-directory-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("list-directory-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -227,29 +250,37 @@ mod tests { for file in &test_files { let path = temp_dir.path().join(file); fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "write_file", "path": path.to_str().unwrap(), "content": "test content", - })) - .tool_id(Some("write-to-file-1234".to_string())) - .build(), - ) - .await - .unwrap(); + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) + .await + .unwrap(); } // Test search let search_result = fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "search_files", "path": temp_dir.path().to_str().unwrap(), "pattern": "test", - })) - .tool_id(Some("search-files-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("search-files-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -272,28 +303,36 @@ mod tests { // Create source file fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "write_file", "path": source.to_str().unwrap(), "content": "test content", - })) - .tool_id(Some("write-to-file-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); // Test move operation let move_result = fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "move_file", "source": source.to_str().unwrap(), "destination": dest.to_str().unwrap(), - })) - .tool_id(Some("move-file-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("move-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -311,14 +350,18 @@ mod tests { // Test invalid path let result = fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "write_file", "path": invalid_path, "content": "test content", - })) - .tool_id(Some("write-to-file-1234".to_string())) - .build(), + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), ) .await; @@ -334,24 +377,33 @@ mod tests { for (i, file) in files.iter().enumerate() { let path = temp_dir.path().join(file); fs_tools - .execute(CallToolArgs::builder() - .arguments(json!({ + .execute( + json!({ "operation": "write_file", "path": path.to_str().unwrap(), "content": format!("content {}", i), - })).build()) + }), + Some( + CallToolArgs::builder() + .tool_id(format!("write-to-file-{}", i).to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); } // Test reading multiple files - let read_result = fs_tools.execute(CallToolArgs::builder() - .arguments(json!({ + let read_result = fs_tools.execute( + json!({ "operation": "read_multiple_files", "paths": files.iter().map(|f| temp_dir.path().join(f).to_str().unwrap().to_string()).collect::>(), - })) - .tool_id(Some("read-multiple-files-1234".to_string())) - .build(), + }), + Some(CallToolArgs::builder() + .tool_id("read-multiple-files-1234".to_string()) + .session_id("session-1234".to_string()) + .build()), ) .await .unwrap(); diff --git a/src/tools/file_system/read.rs b/src/tools/file_system/read.rs index f5c2dcd..7087209 100644 --- a/src/tools/file_system/read.rs +++ b/src/tools/file_system/read.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use futures::future::try_join_all; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; use tokio::fs; @@ -83,8 +83,11 @@ impl ToolProvider for ReadFileTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { - let arguments = arguments.arguments; + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { match arguments["operation"].as_str() { Some("read_file") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/search.rs b/src/tools/file_system/search.rs index 12a66e8..a161ec0 100644 --- a/src/tools/file_system/search.rs +++ b/src/tools/file_system/search.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; use std::path::PathBuf; use tokio::fs; @@ -111,8 +111,11 @@ impl ToolProvider for SearchTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { - let arguments = arguments.arguments; + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { match arguments["operation"].as_str() { Some("search_files") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/write.rs b/src/tools/file_system/write.rs index 331ed9b..d34905e 100644 --- a/src/tools/file_system/write.rs +++ b/src/tools/file_system/write.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; use tokio::fs; @@ -58,8 +58,11 @@ impl ToolProvider for WriteFileTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { - let arguments = arguments.arguments; + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; let content = arguments["content"] .as_str() diff --git a/src/tools/mod.rs b/src/tools/mod.rs index cf321a5..f5a8f81 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -103,15 +103,18 @@ pub struct ListToolsResponse { pub struct CallToolRequest { pub name: String, pub arguments: Value, - pub tool_id: Option, + #[serde(flatten)] + pub metadata: Option, } #[derive(Debug, Serialize, Deserialize, TypedBuilder)] pub struct CallToolArgs { - pub arguments: Value, #[serde(skip_serializing_if = "Option::is_none")] #[builder(default, setter(into))] pub tool_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default, setter(into))] + pub session_id: Option, } // Tool Provider trait @@ -121,7 +124,11 @@ pub trait ToolProvider: Send + Sync { async fn get_tool(&self) -> Tool; /// Execute tool - async fn execute(&self, arguments: CallToolArgs) -> Result; + async fn execute( + &self, + arguments: Value, + metadata: Option, + ) -> Result; } // Tool Manager @@ -278,20 +285,13 @@ impl ToolManager { &self, name: &str, arguments: Value, - tool_id: Option, + metadata: Option, ) -> Result { let tools = self.tools.read().await; let provider = tools .get(name) .ok_or_else(|| McpError::InvalidRequest(format!("Unknown tool: {}", name)))?; - provider - .execute( - CallToolArgs::builder() - .arguments(arguments) - .tool_id(tool_id) - .build(), - ) - .await + provider.execute(arguments, metadata).await } } diff --git a/src/tools/test_tool.rs b/src/tools/test_tool.rs index 2c0e73a..95e8f3d 100644 --- a/src/tools/test_tool.rs +++ b/src/tools/test_tool.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use serde_json::Value; use crate::error::McpError; @@ -40,7 +41,11 @@ impl ToolProvider for TestTool { } } - async fn execute(&self, _arguments: CallToolArgs) -> Result { + async fn execute( + &self, + _arguments: Value, + _metadata: Option, + ) -> Result { Ok(ToolResult { content: vec![], is_error: false, @@ -85,9 +90,12 @@ impl ToolProvider for PingTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let server = arguments - .arguments .get("server") .and_then(|s| s.as_str()) .unwrap_or("localhost"); diff --git a/tests/tools.rs b/tests/tools.rs index 9520038..a1712ce 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use serde_json::json; +use serde_json::{json, Value}; use std::{collections::HashMap, sync::Arc}; use mcp_rs::{ @@ -51,9 +51,13 @@ impl ToolProvider for MockCalculatorTool { } } - async fn execute(&self, arguments: CallToolArgs) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let params: CalculatorParams = - serde_json::from_value(arguments.arguments).map_err(|_| McpError::InvalidParams)?; + serde_json::from_value(arguments).map_err(|_| McpError::InvalidParams)?; let result = match params.operation.as_str() { "add" => params.a + params.b, @@ -129,7 +133,12 @@ async fn test_tool_execution() { "a": 5, "b": 3 }), - Some("calculator-id-1234".to_string()), + Some( + CallToolArgs::builder() + .session_id("calculator-id-1234".to_string()) + .tool_id("calculator-tool-id-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -150,7 +159,12 @@ async fn test_tool_execution() { "a": 1, "b": 0 }), - Some("calculator-id-2345".to_string()), + Some( + CallToolArgs::builder() + .session_id("calculator-id-1234".to_string()) + .tool_id("calculator-tool-id-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -205,7 +219,12 @@ async fn test_invalid_arguments() { "a": 1, "b": 2 }), - Some("calculator-id-3456".to_string()), + Some( + CallToolArgs::builder() + .session_id("calculator-id-1234".to_string()) + .tool_id("calculator-tool-id-1234".to_string()) + .build(), + ), ) .await; From 43eb26e68bdf590a02fedb1da94dfcc4548c2b12 Mon Sep 17 00:00:00 2001 From: Naveen Narayanan Date: Sat, 22 Mar 2025 23:16:58 -0700 Subject: [PATCH 3/4] Working camel Case --- src/tools/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tools/mod.rs b/src/tools/mod.rs index f5a8f81..7af5bc2 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -103,11 +103,11 @@ pub struct ListToolsResponse { pub struct CallToolRequest { pub name: String, pub arguments: Value, - #[serde(flatten)] pub metadata: Option, } #[derive(Debug, Serialize, Deserialize, TypedBuilder)] +#[serde(rename_all = "camelCase")] pub struct CallToolArgs { #[serde(skip_serializing_if = "Option::is_none")] #[builder(default, setter(into))] From 9a62acadeb93ec9bd34da041eca39215a138635a Mon Sep 17 00:00:00 2001 From: Naveen Narayanan Date: Sat, 22 Mar 2025 23:23:32 -0700 Subject: [PATCH 4/4] Flatten --- src/tools/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 7af5bc2..28e004e 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -103,6 +103,8 @@ pub struct ListToolsResponse { pub struct CallToolRequest { pub name: String, pub arguments: Value, + #[serde(flatten)] + #[serde(skip_serializing_if = "Option::is_none")] pub metadata: Option, }