From a53060e97fc26575a33bdadb1b40b726ca89f77e Mon Sep 17 00:00:00 2001 From: Gautam Korlam Date: Fri, 14 Mar 2025 00:22:08 -0700 Subject: [PATCH 1/2] Add support to send in url directly --- src/transport/ws.rs | 112 +++++++++++++++++++++++++++----------------- 1 file changed, 68 insertions(+), 44 deletions(-) diff --git a/src/transport/ws.rs b/src/transport/ws.rs index c323b99..6623a04 100644 --- a/src/transport/ws.rs +++ b/src/transport/ws.rs @@ -21,12 +21,16 @@ use tokio_tungstenite::{ }, }; use tracing; +use url; + +/// Capture config to server/connect +enum WsTransportConfig { + Server { host: String, port: u16 }, + Client { url: String }, +} pub struct WebSocketTransport { - host: String, - port: u16, - client_mode: bool, - use_tls: bool, + config: WsTransportConfig, buffer_size: usize, auth_header: Option, } @@ -34,10 +38,7 @@ pub struct WebSocketTransport { impl WebSocketTransport { pub fn new_server(host: String, port: u16, buffer_size: usize) -> Self { Self { - host, - port, - client_mode: false, - use_tls: false, + config: WsTransportConfig::Server { host, port }, buffer_size, auth_header: None, } @@ -45,10 +46,9 @@ impl WebSocketTransport { pub fn new_client(host: String, port: u16, buffer_size: usize) -> Self { Self { - host, - port, - client_mode: true, - use_tls: false, + config: WsTransportConfig::Client { + url: format!("ws://{}:{}/ws", host, port), + }, buffer_size, auth_header: None, } @@ -56,10 +56,17 @@ impl WebSocketTransport { pub fn new_wss_client(host: String, port: u16, buffer_size: usize) -> Self { Self { - host, - port, - client_mode: true, - use_tls: true, + config: WsTransportConfig::Client { + url: format!("wss://{}:{}/ws", host, port), + }, + buffer_size, + auth_header: None, + } + } + + pub fn new_client_with_url(url: String, buffer_size: usize) -> Self { + Self { + config: WsTransportConfig::Client { url }, buffer_size, auth_header: None, } @@ -279,6 +286,7 @@ impl WebSocketTransport { Ok(addr) => addr, Err(e) => { tracing::error!("Failed to parse host address: {:?}", e); + message_task.abort(); return; } }; @@ -294,27 +302,40 @@ impl WebSocketTransport { } async fn run_client( - host: String, - port: u16, - use_tls: bool, + url: String, auth_header: Option, mut cmd_rx: mpsc::Receiver, event_tx: mpsc::Sender, ) { - let protocol = if use_tls { "wss" } else { "ws" }; - let ws_url = format!("{}://{}:{}/ws", protocol, host, port); - tracing::debug!("Connecting to WebSocket endpoint: {}", ws_url); + tracing::debug!("Connecting to WebSocket endpoint: {}", url); // Connect to the WebSocket server let ws_stream_result = if let Some(auth) = &auth_header { // Create a custom connector with auth header let request = http::Request::builder() - .uri(&ws_url) + .uri(&url) .header("User-Agent", "mcp-rs-client") .header("Authorization", auth) .header("Connection", "Upgrade") .header("Upgrade", "websocket") - .header("Host", format!("{}:{}", host, port)) + .header("Host", { + // Extract host:port from URL for the Host header + if let Ok(parsed_url) = url::Url::parse(&url) { + format!( + "{}:{}", + parsed_url.host_str().unwrap_or("localhost"), + parsed_url + .port() + .unwrap_or(if parsed_url.scheme() == "wss" { + 443 + } else { + 80 + }) + ) + } else { + "localhost:80".to_string() + } + }) .header("Sec-WebSocket-Version", "13") .header( "Sec-WebSocket-Key", @@ -344,7 +365,7 @@ impl WebSocketTransport { } } else { // Use standard connection without auth - connect_async(&ws_url).await + connect_async(&url).await }; // Connect to the WebSocket server @@ -506,22 +527,23 @@ impl Transport for WebSocketTransport { let (cmd_tx, cmd_rx) = mpsc::channel(self.buffer_size); let (event_tx, event_rx) = mpsc::channel(self.buffer_size); - if self.client_mode { - tokio::spawn(Self::run_client( - self.host.clone(), - self.port, - self.use_tls, - self.auth_header.clone(), - cmd_rx, - event_tx, - )); - } else { - tokio::spawn(Self::run_server( - self.host.clone(), - self.port, - cmd_rx, - event_tx, - )); + match &self.config { + WsTransportConfig::Client { url } => { + tokio::spawn(Self::run_client( + url.clone(), + self.auth_header.clone(), + cmd_rx, + event_tx, + )); + } + WsTransportConfig::Server { host, port } => { + tokio::spawn(Self::run_server( + host.clone(), + port.clone(), + cmd_rx, + event_tx, + )); + } } let event_rx = Arc::new(tokio::sync::Mutex::new(event_rx)); @@ -554,6 +576,7 @@ mod tests { let host = "127.0.0.1".to_string(); let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts let mut transport = WebSocketTransport::new_server(host.clone(), port, 32); + let ws_url = format!("ws://{}:{}/ws", host, port); // Start the transport let TransportChannels { cmd_tx, event_rx } = transport.start().await?; @@ -562,7 +585,6 @@ mod tests { sleep(Duration::from_millis(300)).await; // Connect a client to the server - let ws_url = format!("ws://{}:{}/ws", host, port); let (ws_stream, _) = connect_async(&ws_url).await.expect("Failed to connect"); let (mut write, mut read) = ws_stream.split(); @@ -673,6 +695,7 @@ mod tests { // Start a WebSocket server using warp for the client to connect to let host = "127.0.0.1".to_string(); let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts + let ws_url = format!("ws://{}:{}/ws", host, port); // Create a channel to receive messages from the test server let (server_tx, mut server_rx) = mpsc::channel::(32); @@ -711,7 +734,7 @@ mod tests { sleep(Duration::from_millis(100)).await; // Create and start the WebSocket client transport - let mut transport = WebSocketTransport::new_client(host.clone(), port, 32); + let mut transport = WebSocketTransport::new_client_with_url(ws_url, 32); let TransportChannels { cmd_tx, event_rx } = transport.start().await?; // Give the client time to connect @@ -784,6 +807,7 @@ mod tests { // Start a WebSocket server using warp for the client to connect to let host = "127.0.0.1".to_string(); let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts + let ws_url = format!("ws://{}:{}/ws", host, port); // Create a channel to receive messages from the test server let (server_tx, mut server_rx) = mpsc::channel::(32); @@ -835,7 +859,7 @@ mod tests { // Create and start the WebSocket client transport with auth header let auth_header = "Bearer test-token-123".to_string(); - let mut transport = WebSocketTransport::new_client(host.clone(), port, 32) + let mut transport = WebSocketTransport::new_client_with_url(ws_url, 32) .with_auth_header(auth_header.clone()); let TransportChannels { From bb5caf9469f270e1b6195fa87f70b878b5b01635 Mon Sep 17 00:00:00 2001 From: Gautam Korlam Date: Fri, 14 Mar 2025 00:23:59 -0700 Subject: [PATCH 2/2] clippy --- src/transport/ws.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/transport/ws.rs b/src/transport/ws.rs index 6623a04..bebcaff 100644 --- a/src/transport/ws.rs +++ b/src/transport/ws.rs @@ -537,12 +537,7 @@ impl Transport for WebSocketTransport { )); } WsTransportConfig::Server { host, port } => { - tokio::spawn(Self::run_server( - host.clone(), - port.clone(), - cmd_rx, - event_tx, - )); + tokio::spawn(Self::run_server(host.clone(), *port, cmd_rx, event_tx)); } }