diff --git a/src/archivers.rs b/src/archivers.rs index 931b5a2..45fd561 100644 --- a/src/archivers.rs +++ b/src/archivers.rs @@ -273,6 +273,12 @@ mod tests { ip: String::new(), port: 0, }, + robust_query: config::RobustQueryConfig { + enabled: false, + redundancy: 3, + max_retries: 5, + verbose_logs: false, + }, } } diff --git a/src/config.json b/src/config.json index d1600e5..813c500 100644 --- a/src/config.json +++ b/src/config.json @@ -36,5 +36,11 @@ "notifier": { "ip": "127.0.0.1", "port": 4701 + }, + "robust_query": { + "enabled": true, + "redundancy": 3, + "max_retries": 5, + "verbose_logs": true } } \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 292e7d9..577cc4a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -39,6 +39,8 @@ pub struct Config { pub local_source: LocalSource, pub notifier: NotifierConfig, + + pub robust_query: RobustQueryConfig, } #[derive(Debug, serde::Deserialize, Clone)] @@ -95,6 +97,21 @@ pub struct NotifierConfig { pub port: u16, } +#[derive(Debug, serde::Deserialize, Clone)] +pub struct RobustQueryConfig { + /// Master switch. When false, all GET validator routes use the old single-forward path. + pub enabled: bool, + /// Number of matching responses required to consider the result trustworthy. + /// Clamped to min(redundancy, available_nodes) at runtime. + pub redundancy: usize, + /// Maximum number of retry iterations if no consensus is reached in the first batch. + /// Each iteration queries (redundancy - highest_tally_count) additional nodes. + pub max_retries: usize, + /// When true, logs detailed information about robust query results including + /// node IDs, tally counts, and whether consensus was reached. 
+    pub verbose_logs: bool,
+}
+
 /// Load the configuration from the config json file
 /// path is src/config.json
 impl Config {
diff --git a/src/http.rs b/src/http.rs
index d4bf6de..c13320f 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -128,7 +128,14 @@ where
         .await
         {
             Ok(Ok(())) => {
-                let (_method, route) = get_route(&req_buf).unwrap();
+                let (_method, route) = match get_route(&req_buf) {
+                    Some(r) => r,
+                    None => {
+                        eprintln!("Failed to parse HTTP method/route from request");
+                        respond_with_bad_request(&mut client_stream).await?;
+                        continue;
+                    }
+                };
 
                 match get_application(route.as_str()) {
                     Application::Monitor => {
@@ -161,14 +168,25 @@ where
                         }
                     }
                     Application::Validator => {
-                        if let Err(e) = liberdus::handle_request(
-                            req_buf,
-                            &mut client_stream,
-                            liberdus.clone(),
-                            config.clone(),
-                        )
-                        .await
-                        {
+                        let handler_result = if _method == "GET" && config.robust_query.enabled {
+                            liberdus::handle_request_robust(
+                                req_buf,
+                                &mut client_stream,
+                                liberdus.clone(),
+                                config.clone(),
+                            )
+                            .await
+                        } else {
+                            liberdus::handle_request(
+                                req_buf,
+                                &mut client_stream,
+                                liberdus.clone(),
+                                config.clone(),
+                            )
+                            .await
+                        };
+
+                        if let Err(e) = handler_result {
                             eprintln!("Error handling validator request: {}", e);
                         }
                         continue;
@@ -271,6 +289,14 @@ where
     client_stream.write_all(response.as_bytes()).await
 }
 
+pub async fn respond_with_bad_request<S>(client_stream: &mut S) -> Result<(), std::io::Error>
+where
+    S: AsyncWrite + Unpin + Send,
+{
+    let response = "HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
+    client_stream.write_all(response.as_bytes()).await
+}
+
 pub async fn respond_with_notfound<S>(client_stream: &mut S) -> Result<(), std::io::Error>
 where
     S: AsyncWrite + Unpin + Send,
diff --git a/src/lib.rs b/src/lib.rs
index 0f06030..7e672f4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,6 +13,7 @@ pub mod crypto;
 pub mod http;
 pub mod liberdus;
 pub mod notifier;
+pub mod robust_query;
 pub mod rpc;
 pub mod shardus_monitor;
 pub mod
subscription;
diff --git a/src/liberdus.rs b/src/liberdus.rs
index 44c258c..2c95f0d 100644
--- a/src/liberdus.rs
+++ b/src/liberdus.rs
@@ -12,7 +12,6 @@ use std::{
         Arc,
     },
     time::Duration,
-    u128,
 };
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio::time::sleep;
@@ -209,7 +208,7 @@ impl Liberdus {
     /// For a node with:
     /// - `timetaken_ms = 100`
     /// - `max_timeout = 500`
-    /// The bias is calculated as: 
+    /// The bias is calculated as:
     /// ```math
     /// normalized_rtt = (100 - 0.01) / (500 - 0.01) ≈ 0.19996
     /// bias = 1.0 - 0.19996 ≈ 0.80004
@@ -297,7 +296,7 @@ impl Liberdus {
             guard.clone()
         };
 
-        let max_timeout = self.config.max_http_timeout_ms.try_into().unwrap_or(4000); // 3 seconds
+        let max_timeout = self.config.max_http_timeout_ms; // already in milliseconds; no narrowing/fallback needed
 
         let mut sorted_nodes = nodes.as_ref().clone();
         sorted_nodes.sort_by(|a, b| {
@@ -433,6 +432,41 @@ impl Liberdus {
         }
     }
 
+    /// Picks up to `n` distinct consensors, excluding any whose ID is in `exclude`.
+    /// Uses the same selection strategy as get_next_appropriate_consensor
+    /// (round-robin initially, then biased random after prepare_list).
+    /// If fewer than `n` unique nodes are available, returns as many as possible.
+    pub async fn get_n_distinct_consensors(
+        &self,
+        n: usize,
+        exclude: &std::collections::HashSet<String>,
+    ) -> Vec<Consensor> {
+        let mut result = Vec::with_capacity(n);
+        let mut seen = exclude.clone();
+        let mut attempts = 0;
+        let max_attempts = n * 3; // safety cap to avoid infinite loop
+
+        // Safety break if we can't find nodes
+        if self.active_nodelist.load().is_empty() {
+            return result;
+        }
+
+        while result.len() < n && attempts < max_attempts {
+            attempts += 1;
+            match self.get_next_appropriate_consensor_with_retry(3).await {
+                Some((_, node)) => {
+                    if seen.contains(&node.id) {
+                        continue;
+                    }
+                    seen.insert(node.id.clone());
+                    result.push(node);
+                }
+                None => break,
+            }
+        }
+        result
+    }
+
     pub fn set_consensor_trip_ms(&self, node_id: String, trip_ms: u128) {
         // list already prepared on the first round robin, no need to keep recording rtt for nodes
         if self
@@ -789,6 +823,12 @@ mod tests {
                 ip: String::new(),
                 port: 0,
             },
+            robust_query: config::RobustQueryConfig {
+                enabled: true,
+                redundancy: 3,
+                max_retries: 5,
+                verbose_logs: false,
+            },
         }
     }
 
@@ -1179,7 +1219,13 @@ mod tests {
             .await
             .expect_err("connection should fail");
         let err_text = format!("{}", err);
-        assert!(err_text.contains("Connection refused") || err_text.contains("Error connecting"));
+        assert!(
+            err_text.contains("Connection refused")
+                || err_text.contains("Error connecting")
+                || err_text.contains("Timeout connecting"),
+            "unexpected error: {}",
+            err_text
+        );
 
         let mut buf = vec![0u8; 128];
         let n = tokio::time::timeout(std::time::Duration::from_secs(1), peer.read(&mut buf))
@@ -1348,6 +1394,57 @@ where
     Ok(())
 }
 
+/// Like handle_request, but queries multiple validators and returns
+/// the consensus response. Used for read-only GET routes.
+pub async fn handle_request_robust<S>(
+    request_buffer: Vec<u8>,
+    client_stream: &mut S,
+    liberdus: Arc<Liberdus>,
+    config: Arc<Config>,
+) -> Result<(), Box<dyn std::error::Error>>
+where
+    S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+    let result = match crate::robust_query::robust_query(&liberdus, request_buffer, &config).await {
+        Ok(r) => r,
+        Err(e) => {
+            eprintln!("Robust query failed: {}", e);
+            let error_str = e.to_string();
+            if error_str.contains("Timeout")
+                || error_str.contains("timeout")
+                || error_str.contains("Connection refused")
+            {
+                http::respond_with_timeout(client_stream).await?;
+            } else {
+                http::respond_with_internal_error(client_stream).await?;
+            }
+            return Err(e);
+        }
+    };
+
+    if !result.is_robust && config.robust_query.verbose_logs {
+        eprintln!("Warning: returning non-robust result (best-effort)");
+    }
+
+    // Set the same headers as handle_request does
+    let mut response_data = result.response_data;
+    http::set_http_header(&mut response_data, "Connection", "keep-alive");
+    http::set_http_header(
+        &mut response_data,
+        "Keep-Alive",
+        format!("timeout={}", config.tcp_keepalive_time_sec).as_str(),
+    );
+    http::set_http_header(&mut response_data, "Access-Control-Allow-Origin", "*");
+
+    if let Err(e) = client_stream.write_all(&response_data).await {
+        eprintln!("Error relaying robust response to client: {}", e);
+        http::respond_with_internal_error(client_stream).await?;
+        return Err(Box::new(e));
+    }
+
+    Ok(())
+}
+
 /// Returns `(true, )` when `route` is exactly
 /// `/old_receipt/<64-char hex>` with no extra path segments.
 /// Otherwise returns `(false, String::new())`.
diff --git a/src/notifier.rs b/src/notifier.rs
index 1272b2b..b292094 100644
--- a/src/notifier.rs
+++ b/src/notifier.rs
@@ -128,8 +128,8 @@ mod tests {
     use tokio::net::TcpListener;
 
     use crate::config::{
-        Config, LocalSource, NodeFilteringConfig, NotifierConfig, ShardusMonitorProxyConfig,
-        StandaloneNetworkConfig, TLSConfig,
+        Config, LocalSource, NodeFilteringConfig, NotifierConfig, RobustQueryConfig,
+        ShardusMonitorProxyConfig, StandaloneNetworkConfig, TLSConfig,
     };
 
     fn test_config(port: u16) -> Config {
@@ -172,6 +172,12 @@ mod tests {
                 ip: "127.0.0.1".into(),
                 port,
             },
+            robust_query: RobustQueryConfig {
+                enabled: false,
+                redundancy: 3,
+                max_retries: 5,
+                verbose_logs: false,
+            },
         }
     }
diff --git a/src/robust_query.rs b/src/robust_query.rs
new file mode 100644
index 0000000..429d912
--- /dev/null
+++ b/src/robust_query.rs
@@ -0,0 +1,355 @@
+//! Robust query mechanism for validator read routes.
+//!
+//! Queries multiple validators for the same request, tallies responses
+//! for consensus, and returns the majority-agreed response.
+
+use crate::config::Config;
+use crate::http;
+use crate::liberdus::Liberdus;
+use std::collections::HashSet;
+use std::time::Duration;
+use tokio::io::AsyncWriteExt;
+use tokio::net::TcpStream;
+use tokio::time::timeout;
+
+/// One "bucket" in the tally. Tracks a unique response body, how many
+/// nodes returned it, and which nodes those were.
+#[derive(Debug)]
+struct TallyItem {
+    /// The HTTP response body bytes (headers stripped).
+    body: Vec<u8>,
+    /// The full raw HTTP response (headers + body) from the first node
+    /// that produced this body. This is what we send back to the client
+    /// so that headers like Content-Type are preserved.
+    full_response: Vec<u8>,
+    /// How many nodes returned a response whose body matches `self.body`.
+    count: usize,
+    /// IDs of the nodes that returned this response.
+    node_ids: Vec<String>,
+}
+
+/// Collects responses and counts how many nodes agree on each unique body.
+struct Tally {
+    /// The number of matching responses needed to declare a winner.
+    win_count: usize,
+    /// All distinct response buckets seen so far.
+    items: Vec<TallyItem>,
+}
+
+impl Tally {
+    fn new(win_count: usize) -> Self {
+        Tally {
+            win_count,
+            items: Vec::new(),
+        }
+    }
+
+    /// Add a response. Returns `Some(&TallyItem)` if `win_count` is reached.
+    /// Comparison is done on body bytes only (headers stripped).
+    fn add(
+        &mut self,
+        body: Vec<u8>,
+        full_response: Vec<u8>,
+        node_id: String,
+    ) -> Option<&TallyItem> {
+        // Check for existing item
+        if let Some(index) = self.items.iter().position(|item| item.body == body) {
+            let item = &mut self.items[index];
+            item.count += 1;
+            item.node_ids.push(node_id);
+            if item.count >= self.win_count {
+                return Some(item);
+            }
+            return None;
+        }
+
+        // No match found, add new
+        let new_item = TallyItem {
+            body,
+            full_response,
+            count: 1,
+            node_ids: vec![node_id],
+        };
+        self.items.push(new_item);
+
+        // If win_count is 1, return the item we just pushed
+        if self.win_count <= 1 {
+            return self.items.last();
+        }
+
+        None
+    }
+
+    /// Returns the current highest count across all items.
+    fn highest_count(&self) -> usize {
+        self.items.iter().map(|item| item.count).max().unwrap_or(0)
+    }
+
+    /// Returns the TallyItem with the highest count (best-effort fallback).
+    fn highest_count_item(&self) -> Option<&TallyItem> {
+        self.items.iter().max_by_key(|item| item.count)
+    }
+}
+
+/// The result of a robust query across multiple validators.
+pub struct RobustQueryResult {
+    /// The full HTTP response bytes (headers + body) of the winning response.
+    /// This is what gets written back to the client stream.
+    pub response_data: Vec<u8>,
+    /// True if the result met the redundancy threshold.
+    /// False means we're returning best-effort (highest count).
+    pub is_robust: bool,
+}
+
+/// Queries multiple validators for the same request and returns the
+/// response that at least `redundancy` nodes agree on.
+pub async fn robust_query(
+    liberdus: &Liberdus,
+    request_buffer: Vec<u8>,
+    config: &Config,
+) -> Result<RobustQueryResult, Box<dyn std::error::Error>> {
+    let active_nodes = liberdus.active_nodelist.load();
+    let available_count = active_nodes.len();
+    if available_count == 0 {
+        return Err("No nodes available".into());
+    }
+
+    // Clamp redundancy to min(config.redundancy, available_nodes).
+    let redundancy = std::cmp::min(config.robust_query.redundancy, available_count);
+    if redundancy == 0 {
+        return Err("Redundancy is 0".into());
+    }
+
+    let verbose = config.robust_query.verbose_logs;
+
+    let mut tally = Tally::new(redundancy);
+    let mut used_ids = HashSet::new();
+    let mut tries = 0;
+
+    if verbose {
+        eprintln!(
+            "[robust_query] starting: redundancy={}, available_nodes={}, max_retries={}",
+            redundancy, available_count, config.robust_query.max_retries
+        );
+    }
+
+    // Loop until max retries or we run out of nodes to query
+    while tries < config.robust_query.max_retries {
+        tries += 1;
+
+        let to_query = redundancy - tally.highest_count();
+
+        // Pick `to_query` distinct nodes not already used
+        let batch = liberdus
+            .get_n_distinct_consensors(to_query, &used_ids)
+            .await;
+
+        if batch.is_empty() {
+            if verbose {
+                eprintln!(
+                    "[robust_query] iteration {}: no new unique nodes available, breaking",
+                    tries
+                );
+            }
+            break; // ran out of nodes
+        }
+
+        if verbose {
+            let batch_ids: Vec<&str> = batch.iter().map(|n| n.id.as_str()).collect();
+            eprintln!(
+                "[robust_query] iteration {}: querying {} nodes: {:?}",
+                tries,
+                batch.len(),
+                batch_ids
+            );
+        }
+
+        for node in &batch {
+            used_ids.insert(node.id.clone());
+        }
+
+        // Query all nodes in this batch concurrently
+        let mut handles = Vec::new();
+        for node in batch {
+            let buf = request_buffer.clone();
+            let ip_port = format!("{}:{}", node.ip, node.port);
+            let node_id = node.id.clone();
+            // RTT penalty value for connect/write errors (matches send() behavior)
+            let max_timeout = config.max_http_timeout_ms as u64;
+
+            handles.push(tokio::spawn(async move {
+                type BoxErr = Box<dyn std::error::Error + Send + Sync>;
+                let now = std::time::Instant::now();
+
+                // TCP connect - 1s timeout (matches send())
+                let mut server_stream = match timeout(
+                    Duration::from_millis(1000),
+                    TcpStream::connect(&ip_port),
+                )
+                .await
+                {
+                    Ok(Ok(s)) => s,
+                    Ok(Err(e)) => {
+                        let err: BoxErr = format!("Error connecting: {}", e).into();
+                        return (node_id, Err(err), max_timeout as u128);
+                    }
+                    Err(_) => {
+                        let err: BoxErr = "Timeout connecting".into();
+                        return (node_id, Err(err), max_timeout as u128);
+                    }
+                };
+
+                // Write request - 1s timeout (matches send())
+                match timeout(Duration::from_millis(1000), server_stream.write_all(&buf)).await {
+                    Ok(Ok(())) => {} // write succeeded
+                    Ok(Err(e)) => {
+                        let _ = server_stream.shutdown().await;
+                        let err: BoxErr = format!("Error forwarding request: {}", e).into();
+                        return (node_id, Err(err), max_timeout as u128);
+                    }
+                    Err(_) => {
+                        let _ = server_stream.shutdown().await;
+                        let err: BoxErr = "Timeout forwarding request".into();
+                        return (node_id, Err(err), max_timeout as u128);
+                    }
+                }
+
+                // Collect response - 60s safety cap so we don't hang forever
+                let mut resp_data = Vec::new();
+                match timeout(
+                    Duration::from_secs(60),
+                    http::collect_http(&mut server_stream, &mut resp_data),
+                )
+                .await
+                {
+                    Ok(Ok(())) => {
+                        let elapsed_ms = now.elapsed().as_millis();
+                        let _ = server_stream.shutdown().await;
+                        (node_id, Ok(resp_data), elapsed_ms)
+                    }
+                    Ok(Err(e)) => {
+                        let _ = server_stream.shutdown().await;
+                        let err: BoxErr = format!("Error reading response: {}", e).into();
+                        (node_id, Err(err), max_timeout as u128)
+                    }
+                    Err(_) => {
+                        let _ = server_stream.shutdown().await;
+                        let err: BoxErr = "Timeout reading response (60s)".into();
+                        (node_id, Err(err), max_timeout as u128)
+                    }
+                }
+            }));
+        }
+
+        // Wait for all results in this batch
+        let results = futures::future::join_all(handles).await;
+
+        for (node_id, result, elapsed_ms) in results.into_iter().flatten() {
+            // Update RTT bias data for future node selection
+            liberdus.set_consensor_trip_ms(node_id.clone(), elapsed_ms);
+
+            if let Ok(response_data) = result {
+                // Skip empty responses (matches send() behavior)
+                if response_data.is_empty() {
+                    if verbose {
+                        eprintln!(
+                            "[robust_query] node {} returned empty response, skipping",
+                            node_id
+                        );
+                    }
+                    continue;
+                }
+                let body = http::extract_body(&response_data);
+                if let Some(winner) = tally.add(body, response_data, node_id) {
+                    if verbose {
+                        eprintln!(
+                            "[robust_query] consensus reached: count={}, node_ids={:?}",
+                            winner.count, winner.node_ids
+                        );
+                    }
+                    return Ok(RobustQueryResult {
+                        response_data: winner.full_response.clone(),
+                        is_robust: true,
+                    });
+                }
+            } else if let Err(e) = result {
+                if verbose {
+                    eprintln!("[robust_query] node {} failed: {}", node_id, e);
+                }
+            }
+        }
+    }
+
+    // No consensus reached -- return best-effort
+    match tally.highest_count_item() {
+        Some(item) => {
+            if verbose {
+                eprintln!(
+                    "[robust_query] no consensus, best-effort: count={}, node_ids={:?}, total_tries={}",
+                    item.count, item.node_ids, tries
+                );
+            }
+            Ok(RobustQueryResult {
+                response_data: item.full_response.clone(),
+                is_robust: false,
+            })
+        }
+        None => {
+            if verbose {
+                eprintln!(
+                    "[robust_query] no responses from any validator after {} tries",
+                    tries
+                );
+            }
+            Err("No responses from any validator".into())
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tally_consensus_reached() {
+        let mut tally = Tally::new(3);
+
+        // 1. Add first response
+        let res1 = tally.add(vec![1, 2, 3], vec![10, 1, 2, 3], "node1".into());
+        assert!(res1.is_none());
+        assert_eq!(tally.highest_count(), 1);
+
+        // 2. Add second response (match)
+        let res2 = tally.add(vec![1, 2, 3], vec![10, 1, 2, 3], "node2".into());
+        assert!(res2.is_none());
+        assert_eq!(tally.highest_count(), 2);
+
+        // 3. Add third response (match) -> Winner
+        let res3 = tally.add(vec![1, 2, 3], vec![10, 1, 2, 3], "node3".into());
+        assert!(res3.is_some());
+        assert_eq!(res3.unwrap().count, 3);
+        assert_eq!(tally.highest_count(), 3);
+    }
+
+    #[test]
+    fn test_tally_no_consensus() {
+        let mut tally = Tally::new(3);
+
+        tally.add(vec![1], vec![], "node1".into());
+        tally.add(vec![2], vec![], "node2".into());
+        tally.add(vec![1], vec![], "node3".into());
+
+        assert_eq!(tally.highest_count(), 2); // Body [1] has 2
+
+        let best = tally.highest_count_item();
+        assert!(best.is_some());
+        assert_eq!(best.unwrap().body, vec![1]);
+    }
+
+    #[test]
+    fn test_tally_single_winner() {
+        let mut tally = Tally::new(1);
+        let res = tally.add(vec![1], vec![], "node1".into());
+        assert!(res.is_some());
+    }
+}
diff --git a/src/subscription.rs b/src/subscription.rs
index 3bc9ef7..13a5b1c 100644
--- a/src/subscription.rs
+++ b/src/subscription.rs
@@ -431,6 +431,12 @@ pub(crate) mod tests {
             ip: String::new(),
             port: 0,
         },
+        robust_query: crate::config::RobustQueryConfig {
+            enabled: false,
+            redundancy: 3,
+            max_retries: 5,
+            verbose_logs: false,
+        },
     }
 }
diff --git a/tests/consistency_test.rs b/tests/consistency_test.rs
index 724aeed..cd7c00b 100644
--- a/tests/consistency_test.rs
+++ b/tests/consistency_test.rs
@@ -110,6 +110,12 @@ async fn test_nodelist_consistency_lockstep() {
         ip: String::new(),
         port: 0,
     },
+    robust_query: config::RobustQueryConfig {
+        enabled: false,
+        redundancy: 3,
+        max_retries: 5,
+        verbose_logs: false,
+    },
 });
 config.max_http_timeout_ms = 500;
 config.nodelist_refresh_interval_sec = 1;
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index 1929cb6..9c91755 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -46,6 +46,12 @@ fn test_config(port: u16) -> Arc<Config> {
         ip: String::new(),
         port: 0,
     },
+    robust_query: config::RobustQueryConfig {
+        enabled: false,
+        redundancy: 3,
+        max_retries: 5,
+        verbose_logs: false,
+    },
 })
}