Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 63 additions & 44 deletions src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,45 +21,52 @@ use tokio_tungstenite::{
},
};
use tracing;
use url;

/// Capture config to server/connect
enum WsTransportConfig {
Server { host: String, port: u16 },
Client { url: String },
}

pub struct WebSocketTransport {
host: String,
port: u16,
client_mode: bool,
use_tls: bool,
config: WsTransportConfig,
buffer_size: usize,
auth_header: Option<String>,
}

impl WebSocketTransport {
pub fn new_server(host: String, port: u16, buffer_size: usize) -> Self {
Self {
host,
port,
client_mode: false,
use_tls: false,
config: WsTransportConfig::Server { host, port },
buffer_size,
auth_header: None,
}
}

pub fn new_client(host: String, port: u16, buffer_size: usize) -> Self {
Self {
host,
port,
client_mode: true,
use_tls: false,
config: WsTransportConfig::Client {
url: format!("ws://{}:{}/ws", host, port),
},
buffer_size,
auth_header: None,
}
}

pub fn new_wss_client(host: String, port: u16, buffer_size: usize) -> Self {
Self {
host,
port,
client_mode: true,
use_tls: true,
config: WsTransportConfig::Client {
url: format!("wss://{}:{}/ws", host, port),
},
buffer_size,
auth_header: None,
}
}

pub fn new_client_with_url(url: String, buffer_size: usize) -> Self {
Self {
config: WsTransportConfig::Client { url },
buffer_size,
auth_header: None,
}
Expand Down Expand Up @@ -279,6 +286,7 @@ impl WebSocketTransport {
Ok(addr) => addr,
Err(e) => {
tracing::error!("Failed to parse host address: {:?}", e);
message_task.abort();
return;
}
};
Expand All @@ -294,27 +302,40 @@ impl WebSocketTransport {
}

async fn run_client(
host: String,
port: u16,
use_tls: bool,
url: String,
auth_header: Option<String>,
mut cmd_rx: mpsc::Receiver<TransportCommand>,
event_tx: mpsc::Sender<TransportEvent>,
) {
let protocol = if use_tls { "wss" } else { "ws" };
let ws_url = format!("{}://{}:{}/ws", protocol, host, port);
tracing::debug!("Connecting to WebSocket endpoint: {}", ws_url);
tracing::debug!("Connecting to WebSocket endpoint: {}", url);

// Connect to the WebSocket server
let ws_stream_result = if let Some(auth) = &auth_header {
// Create a custom connector with auth header
let request = http::Request::builder()
.uri(&ws_url)
.uri(&url)
.header("User-Agent", "mcp-rs-client")
.header("Authorization", auth)
.header("Connection", "Upgrade")
.header("Upgrade", "websocket")
.header("Host", format!("{}:{}", host, port))
.header("Host", {
// Extract host:port from URL for the Host header
if let Ok(parsed_url) = url::Url::parse(&url) {
format!(
"{}:{}",
parsed_url.host_str().unwrap_or("localhost"),
parsed_url
.port()
.unwrap_or(if parsed_url.scheme() == "wss" {
443
} else {
80
})
)
} else {
"localhost:80".to_string()
}
})
.header("Sec-WebSocket-Version", "13")
.header(
"Sec-WebSocket-Key",
Expand Down Expand Up @@ -344,7 +365,7 @@ impl WebSocketTransport {
}
} else {
// Use standard connection without auth
connect_async(&ws_url).await
connect_async(&url).await
};

// Connect to the WebSocket server
Expand Down Expand Up @@ -506,22 +527,18 @@ impl Transport for WebSocketTransport {
let (cmd_tx, cmd_rx) = mpsc::channel(self.buffer_size);
let (event_tx, event_rx) = mpsc::channel(self.buffer_size);

if self.client_mode {
tokio::spawn(Self::run_client(
self.host.clone(),
self.port,
self.use_tls,
self.auth_header.clone(),
cmd_rx,
event_tx,
));
} else {
tokio::spawn(Self::run_server(
self.host.clone(),
self.port,
cmd_rx,
event_tx,
));
match &self.config {
WsTransportConfig::Client { url } => {
tokio::spawn(Self::run_client(
url.clone(),
self.auth_header.clone(),
cmd_rx,
event_tx,
));
}
WsTransportConfig::Server { host, port } => {
tokio::spawn(Self::run_server(host.clone(), *port, cmd_rx, event_tx));
}
}

let event_rx = Arc::new(tokio::sync::Mutex::new(event_rx));
Expand Down Expand Up @@ -554,6 +571,7 @@ mod tests {
let host = "127.0.0.1".to_string();
let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts
let mut transport = WebSocketTransport::new_server(host.clone(), port, 32);
let ws_url = format!("ws://{}:{}/ws", host, port);

// Start the transport
let TransportChannels { cmd_tx, event_rx } = transport.start().await?;
Expand All @@ -562,7 +580,6 @@ mod tests {
sleep(Duration::from_millis(300)).await;

// Connect a client to the server
let ws_url = format!("ws://{}:{}/ws", host, port);
let (ws_stream, _) = connect_async(&ws_url).await.expect("Failed to connect");
let (mut write, mut read) = ws_stream.split();

Expand Down Expand Up @@ -673,6 +690,7 @@ mod tests {
// Start a WebSocket server using warp for the client to connect to
let host = "127.0.0.1".to_string();
let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts
let ws_url = format!("ws://{}:{}/ws", host, port);

// Create a channel to receive messages from the test server
let (server_tx, mut server_rx) = mpsc::channel::<JsonRpcMessage>(32);
Expand Down Expand Up @@ -711,7 +729,7 @@ mod tests {
sleep(Duration::from_millis(100)).await;

// Create and start the WebSocket client transport
let mut transport = WebSocketTransport::new_client(host.clone(), port, 32);
let mut transport = WebSocketTransport::new_client_with_url(ws_url, 32);
let TransportChannels { cmd_tx, event_rx } = transport.start().await?;

// Give the client time to connect
Expand Down Expand Up @@ -784,6 +802,7 @@ mod tests {
// Start a WebSocket server using warp for the client to connect to
let host = "127.0.0.1".to_string();
let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts
let ws_url = format!("ws://{}:{}/ws", host, port);

// Create a channel to receive messages from the test server
let (server_tx, mut server_rx) = mpsc::channel::<JsonRpcMessage>(32);
Expand Down Expand Up @@ -835,7 +854,7 @@ mod tests {

// Create and start the WebSocket client transport with auth header
let auth_header = "Bearer test-token-123".to_string();
let mut transport = WebSocketTransport::new_client(host.clone(), port, 32)
let mut transport = WebSocketTransport::new_client_with_url(ws_url, 32)
.with_auth_header(auth_header.clone());

let TransportChannels {
Expand Down