diff --git a/Cargo.lock b/Cargo.lock index f22dfe5..f35f206 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1147,6 +1147,7 @@ dependencies = [ "tracing", "tracing-core", "tracing-subscriber", + "typed-builder", "url", "uuid", "warp", @@ -2372,6 +2373,26 @@ dependencies = [ "utf-8", ] +[[package]] +name = "typed-builder" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce63bcaf7e9806c206f7d7b9c1f38e0dce8bb165a80af0898161058b19248534" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60d8d828da2a3d759d3519cdf29a5bac49c77d039ad36d0782edadbf9cd5415b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/Cargo.toml b/Cargo.toml index 533840f..73c2f06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ reqwest = { version = "0.12", features = ["json"] } reqwest-eventsource = "0.6" mime = "0.3" tokio-tungstenite = { version = "0.26", features = ["native-tls"] } +typed-builder = "0.21.0" tracing-core = "0.1" async-recursion = "1" http = "1.1" diff --git a/bin/client.rs b/bin/client.rs index cf01b07..6a01db4 100644 --- a/bin/client.rs +++ b/bin/client.rs @@ -2,9 +2,8 @@ use clap::{Parser, Subcommand}; use mcp_rs::{ client::{Client, ClientInfo}, error::McpError, - transport::sse::SseTransport, - transport::stdio::StdioTransport, - transport::ws::WebSocketTransport, + tools::CallToolArgs, + transport::{sse::SseTransport, stdio::StdioTransport, ws::WebSocketTransport}, }; use serde_json::json; @@ -69,6 +68,10 @@ enum Commands { name: String, #[arg(short, long)] args: String, + #[arg(short, long)] + tool_id: Option, + #[arg(short, long)] + session_id: Option, }, /// Set log level SetLogLevel { @@ -236,10 +239,26 @@ async fn main() -> Result<(), McpError> { println!("{}", json!(res)); } - Commands::CallTool { name, args } => { + Commands::CallTool { + name, + args, + tool_id, + session_id, + } => { let arguments = serde_json::from_str(&args).map_err(|e| McpError::InvalidRequest(e.to_string()))?; - let res = client.call_tool(name, arguments).await?; + let res = client + .call_tool( + name, + arguments, + Some( + CallToolArgs::builder() + .session_id(session_id) + .tool_id(tool_id) + .build(), + ), + ) + .await?; println!("{}", json!(res)); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 67bb13a..d3391c7 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -9,7 +9,10 @@ use crate::{ ListResourcesRequest, ListResourcesResponse, ReadResourceRequest, ReadResourceResponse, ResourceCapabilities, }, - tools::{CallToolRequest, ListToolsRequest, ListToolsResponse, ToolCapabilities, ToolResult}, + tools::{ + CallToolArgs, CallToolRequest, ListToolsRequest, ListToolsResponse, ToolCapabilities, + ToolResult, + }, transport::{Transport, TransportCommand}, }; use serde::{Deserialize, Serialize}; @@ -245,6 +248,7 @@ impl Client { &self, name: String, arguments: serde_json::Value, + metadata: Option, ) -> Result { self.assert_initialized().await?; self.assert_capability("tools").await?; @@ -252,7 +256,11 @@ impl Client { self.protocol .request( "tools/call", - Some(CallToolRequest { name, arguments }), + Some(CallToolRequest { + name, + arguments, + metadata, + }), None, ) .await diff --git a/src/server/mod.rs b/src/server/mod.rs index 3ef44ae..abbc52e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -347,7 +347,7 @@ where serde_json::from_value(req.params.unwrap_or_default()) .map_err(|_| McpError::InvalidParams)?; let result = tool_manager - .call_tool(¶ms.name, params.arguments) + .call_tool(¶ms.name, params.arguments, params.metadata) .await?; Ok(serde_json::to_value(result).unwrap()) }) diff --git a/src/tools/calculator.rs b/src/tools/calculator.rs index 14b849f..34ab525 100644 --- a/src/tools/calculator.rs +++ b/src/tools/calculator.rs @@ -7,7 +7,7 @@ use crate::{ protocol::{RequestHandler, ServerCapabilities}, }; -use super::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}; +use super::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}; // Domain types #[derive(Debug, serde::Deserialize)] @@ -205,7 +205,11 @@ impl ToolProvider for CalculatorTool { } } - async fn execute(&self, arguments: serde_json::Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let params: CalculatorParams = serde_json::from_value(arguments).map_err(|e| { tracing::error!("Error parsing calculator arguments: {:?}", e); McpError::InvalidParams @@ -301,6 +305,12 @@ mod tests { "a": 2.0, "b": 3.0 }), + Some( + CallToolArgs::builder() + .session_id(Some("session-1234".to_string())) + .tool_id(Some("calculator-1234".to_string())) + .build(), + ), ) .await .unwrap(); @@ -319,6 +329,12 @@ mod tests { "operation": "ln", "a": 2.718281828459045 }), + Some( + CallToolArgs::builder() + .session_id(Some("session-5678".to_string())) + .tool_id(Some("calculator-5678".to_string())) + .build(), + ), ) .await .unwrap(); @@ -350,6 +366,12 @@ mod tests { "a": -1.0, "b": 10.0 }), + Some( + CallToolArgs::builder() + .session_id(Some("session-9999".to_string())) + .tool_id(Some("calculator-9999".to_string())) + .build(), + ), ) .await .unwrap(); diff --git a/src/tools/file_system/directory.rs b/src/tools/file_system/directory.rs index 3fa8bba..2daabcd 100644 --- a/src/tools/file_system/directory.rs +++ b/src/tools/file_system/directory.rs @@ -5,7 +5,7 @@ use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct DirectoryTool; @@ -57,7 +57,11 @@ impl ToolProvider for DirectoryTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { match arguments["operation"].as_str() { Some("create_directory") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/mod.rs b/src/tools/file_system/mod.rs index e1c2bae..6654980 100644 --- a/src/tools/file_system/mod.rs +++ b/src/tools/file_system/mod.rs @@ -12,6 +12,8 @@ use serde_json::Value; use std::path::PathBuf; use std::sync::Arc; +use super::CallToolArgs; + #[derive(Clone)] pub struct FileSystemTools { read_tool: Arc, @@ -92,7 +94,11 @@ impl ToolProvider for FileSystemTools { tools.remove(0) } - async fn execute(&self, arguments: Value) -> Result { + async fn execute( + &self, + arguments: Value, + metadata: Option, + ) -> Result { // Add operation to list allowed directories if arguments["operation"].as_str() == Some("list_allowed_directories") { let dirs = self @@ -115,12 +121,14 @@ impl ToolProvider for FileSystemTools { .ok_or(McpError::InvalidParams)?; match operation { - "read_file" | "read_multiple_files" => self.read_tool.execute(arguments).await, - "write_file" => self.write_tool.execute(arguments).await, + "read_file" | "read_multiple_files" => { + self.read_tool.execute(arguments, metadata).await + } + "write_file" => self.write_tool.execute(arguments, metadata).await, "create_directory" | "list_directory" | "move_file" => { - self.directory_tool.execute(arguments).await + self.directory_tool.execute(arguments, metadata).await } - "search_files" | "get_file_info" => self.search_tool.execute(arguments).await, + "search_files" | "get_file_info" => self.search_tool.execute(arguments, metadata).await, _ => Err(McpError::InvalidParams), } } @@ -147,21 +155,37 @@ mod tests { // Test write operation let write_result = fs_tools - .execute(json!({ - "operation": "write_file", - "path": test_file.to_str().unwrap(), - "content": test_content, - })) + .execute( + json!({ + "operation": "write_file", + "path": test_file.to_str().unwrap(), + "content": test_content, + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); assert!(!write_result.is_error); // Test read operation let read_result = fs_tools - .execute(json!({ - "operation": "read_file", - "path": test_file.to_str().unwrap(), - })) + .execute( + json!({ + "operation": "read_file", + "path": test_file.to_str().unwrap(), + }), + Some( + CallToolArgs::builder() + .tool_id("read-from-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); @@ -178,20 +202,36 @@ mod tests { // Test directory creation let create_result = fs_tools - .execute(json!({ - "operation": "create_directory", - "path": test_dir.to_str().unwrap(), - })) + .execute( + json!({ + "operation": "create_directory", + "path": test_dir.to_str().unwrap(), + }), + Some( + CallToolArgs::builder() + .tool_id("create-directory-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); assert!(!create_result.is_error); // Test directory listing let list_result = fs_tools - .execute(json!({ - "operation": "list_directory", - "path": temp_dir.path().to_str().unwrap(), - })) + .execute( + json!({ + "operation": "list_directory", + "path": temp_dir.path().to_str().unwrap(), + }), + Some( + CallToolArgs::builder() + .tool_id("list-directory-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); @@ -210,22 +250,38 @@ mod tests { for file in &test_files { let path = temp_dir.path().join(file); fs_tools - .execute(json!({ - "operation": "write_file", - "path": path.to_str().unwrap(), - "content": "test content", - })) + .execute( + json!({ + "operation": "write_file", + "path": path.to_str().unwrap(), + "content": "test content", + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); } // Test search let search_result = fs_tools - .execute(json!({ - "operation": "search_files", - "path": temp_dir.path().to_str().unwrap(), - "pattern": "test", - })) + .execute( + json!({ + "operation": "search_files", + "path": temp_dir.path().to_str().unwrap(), + "pattern": "test", + }), + Some( + CallToolArgs::builder() + .tool_id("search-files-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); @@ -247,21 +303,37 @@ mod tests { // Create source file fs_tools - .execute(json!({ - "operation": "write_file", - "path": source.to_str().unwrap(), - "content": "test content", - })) + .execute( + json!({ + "operation": "write_file", + "path": source.to_str().unwrap(), + "content": "test content", + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); // Test move operation let move_result = fs_tools - .execute(json!({ - "operation": "move_file", - "source": source.to_str().unwrap(), - "destination": dest.to_str().unwrap(), - })) + .execute( + json!({ + "operation": "move_file", + "source": source.to_str().unwrap(), + "destination": dest.to_str().unwrap(), + }), + Some( + CallToolArgs::builder() + .tool_id("move-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); assert!(!move_result.is_error); @@ -278,11 +350,19 @@ mod tests { // Test invalid path let result = fs_tools - .execute(json!({ - "operation": "write_file", - "path": invalid_path, - "content": "test content", - })) + .execute( + json!({ + "operation": "write_file", + "path": invalid_path, + "content": "test content", + }), + Some( + CallToolArgs::builder() + .tool_id("write-to-file-1234".to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await; assert!(result.is_err()); @@ -297,20 +377,36 @@ mod tests { for (i, file) in files.iter().enumerate() { let path = temp_dir.path().join(file); fs_tools - .execute(json!({ - "operation": "write_file", - "path": path.to_str().unwrap(), - "content": format!("content {}", i), - })) + .execute( + json!({ + "operation": "write_file", + "path": path.to_str().unwrap(), + "content": format!("content {}", i), + }), + Some( + CallToolArgs::builder() + .tool_id(format!("write-to-file-{}", i).to_string()) + .session_id("session-1234".to_string()) + .build(), + ), + ) .await .unwrap(); } // Test reading multiple files - let read_result = fs_tools.execute(json!({ - "operation": "read_multiple_files", - "paths": files.iter().map(|f| temp_dir.path().join(f).to_str().unwrap().to_string()).collect::>(), - })).await.unwrap(); + let read_result = fs_tools.execute( + json!({ + "operation": "read_multiple_files", + "paths": files.iter().map(|f| temp_dir.path().join(f).to_str().unwrap().to_string()).collect::>(), + }), + Some(CallToolArgs::builder() + .tool_id("read-multiple-files-1234".to_string()) + .session_id("session-1234".to_string()) + .build()), + ) + .await + .unwrap(); assert_eq!(read_result.content.len(), 2); match &read_result.content[0] { diff --git a/src/tools/file_system/read.rs b/src/tools/file_system/read.rs index 0af362a..7087209 100644 --- a/src/tools/file_system/read.rs +++ b/src/tools/file_system/read.rs @@ -6,7 +6,7 @@ use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct ReadFileTool; @@ -83,7 +83,11 @@ impl ToolProvider for ReadFileTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { match arguments["operation"].as_str() { Some("read_file") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/search.rs b/src/tools/file_system/search.rs index e511e68..a161ec0 100644 --- a/src/tools/file_system/search.rs +++ b/src/tools/file_system/search.rs @@ -6,7 +6,7 @@ use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct SearchTool; @@ -111,7 +111,11 @@ impl ToolProvider for SearchTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { match arguments["operation"].as_str() { Some("search_files") => { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; diff --git a/src/tools/file_system/write.rs b/src/tools/file_system/write.rs index 726e73e..d34905e 100644 --- a/src/tools/file_system/write.rs +++ b/src/tools/file_system/write.rs @@ -5,7 +5,7 @@ use tokio::fs; use crate::{ error::McpError, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; pub struct WriteFileTool; @@ -58,7 +58,11 @@ impl ToolProvider for WriteFileTool { } } - async fn execute(&self, arguments: Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let path = arguments["path"].as_str().ok_or(McpError::InvalidParams)?; let content = arguments["content"] .as_str() diff --git a/src/tools/mod.rs b/src/tools/mod.rs index ae69b61..28e004e 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -4,6 +4,7 @@ use serde_json::Value; use std::{collections::HashMap, sync::Arc}; use test_tool::{PingTool, TestTool}; use tokio::sync::{mpsc, RwLock}; +use typed_builder::TypedBuilder; pub mod calculator; pub mod file_system; @@ -102,6 +103,20 @@ pub struct ListToolsResponse { pub struct CallToolRequest { pub name: String, pub arguments: Value, + #[serde(flatten)] + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize, TypedBuilder)] +#[serde(rename_all = "camelCase")] +pub struct CallToolArgs { + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default, setter(into))] + pub tool_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[builder(default, setter(into))] + pub session_id: Option, } // Tool Provider trait @@ -111,7 +126,11 @@ pub trait ToolProvider: Send + Sync { async fn get_tool(&self) -> Tool; /// Execute tool - async fn execute(&self, arguments: Value) -> Result; + async fn execute( + &self, + arguments: Value, + metadata: Option, + ) -> Result; } // Tool Manager @@ -264,12 +283,17 @@ impl ToolManager { }) } - pub async fn call_tool(&self, name: &str, arguments: Value) -> Result { + pub async fn call_tool( + &self, + name: &str, + arguments: Value, + metadata: Option, + ) -> Result { let tools = self.tools.read().await; let provider = tools .get(name) .ok_or_else(|| McpError::InvalidRequest(format!("Unknown tool: {}", name)))?; - provider.execute(arguments).await + provider.execute(arguments, metadata).await } } diff --git a/src/tools/test_tool.rs b/src/tools/test_tool.rs index 2fc3cad..95e8f3d 100644 --- a/src/tools/test_tool.rs +++ b/src/tools/test_tool.rs @@ -1,8 +1,9 @@ use async_trait::async_trait; +use serde_json::Value; use crate::error::McpError; -use super::{Tool, ToolContent, ToolProvider, ToolResult}; +use super::{CallToolArgs, Tool, ToolContent, ToolProvider, ToolResult}; pub struct TestTool; @@ -40,7 +41,11 @@ impl ToolProvider for TestTool { } } - async fn execute(&self, _arguments: serde_json::Value) -> Result { + async fn execute( + &self, + _arguments: Value, + _metadata: Option, + ) -> Result { Ok(ToolResult { content: vec![], is_error: false, @@ -85,7 +90,11 @@ impl ToolProvider for PingTool { } } - async fn execute(&self, arguments: serde_json::Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let server = arguments .get("server") .and_then(|s| s.as_str()) diff --git a/tests/tools.rs b/tests/tools.rs index 0dd4e0f..a1712ce 100644 --- a/tests/tools.rs +++ b/tests/tools.rs @@ -1,12 +1,12 @@ use async_trait::async_trait; -use serde_json::json; +use serde_json::{json, Value}; use std::{collections::HashMap, sync::Arc}; use mcp_rs::{ error::McpError, protocol::BasicRequestHandler, server::{config::ServerConfig, McpServer}, - tools::{Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, + tools::{CallToolArgs, Tool, ToolContent, ToolInputSchema, ToolProvider, ToolResult}, }; // Mock tool provider for testing @@ -51,7 +51,11 @@ impl ToolProvider for MockCalculatorTool { } } - async fn execute(&self, arguments: serde_json::Value) -> Result { + async fn execute( + &self, + arguments: Value, + _metadata: Option, + ) -> Result { let params: CalculatorParams = serde_json::from_value(arguments).map_err(|_| McpError::InvalidParams)?; @@ -129,6 +133,12 @@ async fn test_tool_execution() { "a": 5, "b": 3 }), + Some( + CallToolArgs::builder() + .session_id("calculator-id-1234".to_string()) + .tool_id("calculator-tool-id-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -149,6 +159,12 @@ async fn test_tool_execution() { "a": 1, "b": 0 }), + Some( + CallToolArgs::builder() + .session_id("calculator-id-1234".to_string()) + .tool_id("calculator-tool-id-1234".to_string()) + .build(), + ), ) .await .unwrap(); @@ -170,7 +186,7 @@ async fn test_invalid_tool() { // Test calling non-existent tool let result = server .tool_manager - .call_tool("nonexistent", json!({})) + .call_tool("nonexistent", json!({}), None) .await; assert!(result.is_err()); @@ -203,6 +219,12 @@ async fn test_invalid_arguments() { "a": 1, "b": 2 }), + Some( + CallToolArgs::builder() + .session_id("calculator-id-1234".to_string()) + .tool_id("calculator-tool-id-1234".to_string()) + .build(), + ), ) .await;