From 8ffd70ddfd497dee73c7c9f645a8469ff5faddc3 Mon Sep 17 00:00:00 2001 From: Tom Brzozowski Date: Sun, 12 Oct 2025 18:35:13 +0100 Subject: [PATCH 1/6] use poll action --- benches/latency/endpoint.rs | 24 ++++--- benches/latency/main.rs | 6 +- examples/common/mod.rs | 70 +++++++++++---------- examples/endpoint_with_timer.rs | 35 ++++++----- examples/io_service_dispatch.rs | 4 +- examples/io_service_with_async_dns.rs | 6 +- examples/io_service_with_auto_disconnect.rs | 2 +- examples/io_service_with_context.rs | 6 +- examples/io_service_with_direct_selector.rs | 22 +++---- examples/io_service_without_context.rs | 2 +- examples/polymorphic_endpoints.rs | 50 ++++++++------- src/service/endpoint.rs | 21 +------ src/service/mod.rs | 23 ++++--- src/service/select/direct.rs | 5 +- src/service/select/mio.rs | 5 +- 15 files changed, 139 insertions(+), 142 deletions(-) diff --git a/benches/latency/endpoint.rs b/benches/latency/endpoint.rs index 32035f3..4faecb8 100644 --- a/benches/latency/endpoint.rs +++ b/benches/latency/endpoint.rs @@ -45,8 +45,21 @@ impl EndpointWithContext for TestEndpoint { .into_websocket("/"); Ok(Some(ws)) } +} + +impl TestEndpoint { + pub fn new(port: u16, payload: &'static str) -> Self { + Self { + connection_info: ConnectionInfo::new("127.0.0.1", port), + payload, + } + } - fn poll(&mut self, ws: &mut Self::Target, ctx: &mut TestContext) -> std::io::Result<()> { + pub fn poll( + &mut self, + ws: &mut >::Target, + ctx: &mut TestContext, + ) -> std::io::Result<()> { if ctx.wants_write { ws.send_text(true, Some(self.payload.as_bytes()))?; ctx.wants_write = false; @@ -59,12 +72,3 @@ impl EndpointWithContext for TestEndpoint { Ok(()) } } - -impl TestEndpoint { - pub fn new(port: u16, payload: &'static str) -> Self { - Self { - connection_info: ConnectionInfo::new("127.0.0.1", port), - payload, - } - } -} diff --git a/benches/latency/main.rs b/benches/latency/main.rs index 14218e2..18f3b86 100644 --- a/benches/latency/main.rs +++ b/benches/latency/main.rs @@ -62,13 +62,15 @@ fn boomnet_rtt_benchmark_io_service(c: &mut Criterion) { // setup io service let mut ctx = TestContext::new(); - let mut io_service = DirectSelector::new().unwrap().into_io_service_with_context(&mut ctx); + let mut io_service = DirectSelector::new().unwrap().into_io_service_with_context(); io_service.register(TestEndpoint::new(9003, MSG)).unwrap(); group.bench_function("boomnet_rtt_io_service", |b| { b.iter(|| { loop { - io_service.poll(&mut ctx).unwrap(); + io_service + .poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll(ws, ctx)) + .unwrap(); if ctx.processed == 100 { ctx.wants_write = true; ctx.processed = 0; diff --git a/examples/common/mod.rs b/examples/common/mod.rs index bf13a01..8f1690e 100644 --- a/examples/common/mod.rs +++ b/examples/common/mod.rs @@ -61,6 +61,44 @@ impl TradeEndpoint { )?; Ok(()) } + + #[inline] + #[allow(dead_code)] + pub fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, data) = frame? { + match self.id % 4 { + 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), + 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), + 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), + 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), + _ => {} + } + } + } + Ok(()) + } + + #[inline] + #[allow(dead_code)] + pub fn poll_ctx( + &mut self, + ws: &mut TlsWebsocket<::Stream>, + _ctx: &mut FeedContext, + ) -> io::Result<()> { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, data) = frame? { + match self.id % 4 { + 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), + 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), + 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), + 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), + _ => {} + } + } + } + Ok(()) + } } impl ConnectionInfoProvider for TradeEndpoint { @@ -85,22 +123,6 @@ impl TlsWebsocketEndpoint for TradeEndpoint { Ok(Some(ws)) } - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - for frame in ws.read_batch()? { - if let WebsocketFrame::Text(fin, data) = frame? { - match self.id % 4 { - 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), - 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), - 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), - 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), - _ => {} - } - } - } - Ok(()) - } - fn can_recreate(&mut self, reason: DisconnectReason) -> bool { warn!("connection disconnected: {reason}"); true @@ -125,20 +147,4 @@ impl TlsWebsocketEndpointWithContext for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket, _ctx: &mut FeedContext) -> io::Result<()> { - for frame in ws.read_batch()? { - if let WebsocketFrame::Text(fin, data) = frame? { - match self.id % 4 { - 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), - 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), - 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), - 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), - _ => {} - } - } - } - Ok(()) - } } diff --git a/examples/endpoint_with_timer.rs b/examples/endpoint_with_timer.rs index 66d5410..bba3618 100644 --- a/examples/endpoint_with_timer.rs +++ b/examples/endpoint_with_timer.rs @@ -30,6 +30,23 @@ impl TradeEndpoint { next_disconnect_time_ns: ctx.current_time_ns() + Duration::from_secs(10).as_nanos() as u64, } } + + #[inline] + fn poll( + &mut self, + ws: &mut TlsWebsocket<>::Stream>, + ctx: &mut FeedContext, + ) -> io::Result<()> { + while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { + info!("({fin}) {}", String::from_utf8_lossy(data)); + } + let now_ns = ctx.current_time_ns(); + if now_ns > self.next_disconnect_time_ns { + self.next_disconnect_time_ns = now_ns + Duration::from_secs(10).as_nanos() as u64; + return Err(io::Error::other("disconnected due to timer")); + } + Ok(()) + } } #[derive(Debug)] @@ -75,19 +92,6 @@ impl TlsWebsocketEndpointWithContext for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket, ctx: &mut FeedContext) -> io::Result<()> { - while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { - info!("({fin}) {}", String::from_utf8_lossy(data)); - } - let now_ns = ctx.current_time_ns(); - if now_ns > self.next_disconnect_time_ns { - self.next_disconnect_time_ns = now_ns + Duration::from_secs(10).as_nanos() as u64; - return Err(io::Error::other("disconnected due to timer")); - } - Ok(()) - } } fn main() -> anyhow::Result<()> { @@ -95,13 +99,12 @@ fn main() -> anyhow::Result<()> { let mut ctx = FeedContext::new(); - let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut ctx); + let mut io_service = MioSelector::new()?.into_io_service_with_context(); let endpoint_btc = TradeEndpoint::new("wss://stream1.binance.com:443/ws", "btcusdt", &ctx); io_service.register(endpoint_btc)?; - loop { - io_service.poll(&mut ctx)?; + io_service.poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll(ws, ctx))?; } } diff --git a/examples/io_service_dispatch.rs b/examples/io_service_dispatch.rs index 634801b..333e6ed 100644 --- a/examples/io_service_dispatch.rs +++ b/examples/io_service_dispatch.rs @@ -23,11 +23,11 @@ fn main() -> anyhow::Result<()> { if success { break; } else { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/io_service_with_async_dns.rs b/examples/io_service_with_async_dns.rs index 57a5a18..d773d6c 100644 --- a/examples/io_service_with_async_dns.rs +++ b/examples/io_service_with_async_dns.rs @@ -9,10 +9,10 @@ mod common; fn main() -> anyhow::Result<()> { env_logger::init(); - let mut context = FeedContext::new(); + let mut ctx = FeedContext::new(); let mut io_service = MioSelector::new()? - .into_io_service_with_context(&mut context) + .into_io_service_with_context() .with_dns_resolver(AsyncDnsResolver::new()?); let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); @@ -24,6 +24,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll(&mut context)?; + io_service.poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll_ctx(ws, ctx))?; } } diff --git a/examples/io_service_with_auto_disconnect.rs b/examples/io_service_with_auto_disconnect.rs index 309f444..9d94e4f 100644 --- a/examples/io_service_with_auto_disconnect.rs +++ b/examples/io_service_with_auto_disconnect.rs @@ -24,6 +24,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_btc_2)?; loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/io_service_with_context.rs b/examples/io_service_with_context.rs index 1403020..f68fbc9 100644 --- a/examples/io_service_with_context.rs +++ b/examples/io_service_with_context.rs @@ -8,9 +8,9 @@ mod common; fn main() -> anyhow::Result<()> { env_logger::init(); - let mut context = FeedContext::new(); + let mut ctx = FeedContext::new(); - let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut context); + let mut io_service = MioSelector::new()?.into_io_service_with_context(); let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); let endpoint_eth = TradeEndpoint::new(1, "wss://stream2.binance.com:443/ws", None, "ethusdt"); @@ -21,6 +21,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll(&mut context)?; + io_service.poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll_ctx(ws, ctx))?; } } diff --git a/examples/io_service_with_direct_selector.rs b/examples/io_service_with_direct_selector.rs index 10b436c..f1d608e 100644 --- a/examples/io_service_with_direct_selector.rs +++ b/examples/io_service_with_direct_selector.rs @@ -33,6 +33,16 @@ impl TradeEndpoint { ws_endpoint, } } + + #[inline] + fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, data) = frame? { + println!("[{}] ({fin}) {}", self.id, String::from_utf8_lossy(data)) + } + } + Ok(()) + } } impl ConnectionInfoProvider for TradeEndpoint { @@ -57,16 +67,6 @@ impl TlsWebsocketEndpoint for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - for frame in ws.read_batch()? { - if let WebsocketFrame::Text(fin, data) = frame? { - println!("[{}] ({fin}) {}", self.id, String::from_utf8_lossy(data)) - } - } - Ok(()) - } } fn main() -> anyhow::Result<()> { @@ -83,6 +83,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/io_service_without_context.rs b/examples/io_service_without_context.rs index 5623620..319e206 100644 --- a/examples/io_service_without_context.rs +++ b/examples/io_service_without_context.rs @@ -19,6 +19,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/polymorphic_endpoints.rs b/examples/polymorphic_endpoints.rs index e884d4a..f4a2b75 100644 --- a/examples/polymorphic_endpoints.rs +++ b/examples/polymorphic_endpoints.rs @@ -22,6 +22,15 @@ enum MarketDataEndpoint { Ticker(TickerEndpoint), } +impl MarketDataEndpoint { + fn poll(&mut self, ws: &mut Websocket::Stream>>) -> io::Result<()> { + match self { + MarketDataEndpoint::Ticker(ticker) => ticker.poll(ws), + MarketDataEndpoint::Trade(trade) => trade.poll(ws), + } + } +} + impl ConnectionInfoProvider for MarketDataEndpoint { fn connection_info(&self) -> &ConnectionInfo { match self { @@ -40,13 +49,6 @@ impl TlsWebsocketEndpoint for MarketDataEndpoint { MarketDataEndpoint::Trade(trade) => trade.create_websocket(addr), } } - - fn poll(&mut self, ws: &mut Websocket>) -> io::Result<()> { - match self { - MarketDataEndpoint::Ticker(ticker) => TlsWebsocketEndpoint::poll(ticker, ws), - MarketDataEndpoint::Trade(trade) => TlsWebsocketEndpoint::poll(trade, ws), - } - } } struct TradeEndpoint { @@ -64,6 +66,14 @@ impl TradeEndpoint { instrument, } } + + #[inline] + fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { + info!("({fin}) {}", String::from_utf8_lossy(data)); + } + Ok(()) + } } impl ConnectionInfoProvider for TradeEndpoint { @@ -90,14 +100,6 @@ impl TlsWebsocketEndpoint for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { - info!("({fin}) {}", String::from_utf8_lossy(data)); - } - Ok(()) - } } struct TickerEndpoint { @@ -115,6 +117,14 @@ impl TickerEndpoint { instrument, } } + + #[inline] + fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { + info!("({fin}) {}", String::from_utf8_lossy(data)); + } + Ok(()) + } } impl ConnectionInfoProvider for TickerEndpoint { @@ -141,14 +151,6 @@ impl TlsWebsocketEndpoint for TickerEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { - info!("({fin}) {}", String::from_utf8_lossy(data)); - } - Ok(()) - } } fn main() -> anyhow::Result<()> { @@ -163,6 +165,6 @@ fn main() -> anyhow::Result<()> { io_service.register(trade); loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/src/service/endpoint.rs b/src/service/endpoint.rs index 0da9de7..2b0c27c 100644 --- a/src/service/endpoint.rs +++ b/src/service/endpoint.rs @@ -16,8 +16,8 @@ pub trait Endpoint: ConnectionInfoProvider { /// await the next connection attempt with (possibly) different `addr`. fn create_target(&mut self, addr: SocketAddr) -> io::Result>; - /// Called by the `IOService` on each duty cycle. - fn poll(&mut self, target: &mut Self::Target) -> io::Result<()>; + // /// Called by the `IOService` on each duty cycle. + // fn poll(&mut self, target: &mut Self::Target) -> io::Result<()>; /// Upon disconnection `IOService` will query the endpoint if the connection can be /// recreated, passing the disconnect `reason`. If not, it will cause program to panic. @@ -48,9 +48,6 @@ pub trait EndpointWithContext: ConnectionInfoProvider { /// return `Ok(None)` and await the next connection attempt with (possibly) different `addr`. fn create_target(&mut self, addr: SocketAddr, context: &mut C) -> io::Result>; - /// Called by the `IOService` on each duty cycle passing user provided `Context`. - fn poll(&mut self, target: &mut Self::Target, context: &mut C) -> io::Result<()>; - /// Upon disconnection `IOService` will query the endpoint if the connection can be /// recreated, passing the disconnect `reason`. If not, it will cause program to panic. fn can_recreate(&mut self, _reason: DisconnectReason, _context: &mut C) -> bool { @@ -116,8 +113,6 @@ pub mod ws { fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>>>; - fn poll(&mut self, ws: &mut Websocket>) -> io::Result<()>; - fn can_recreate(&mut self, _reason: DisconnectReason) -> bool { true } @@ -138,11 +133,6 @@ pub mod ws { self.create_websocket(addr) } - #[inline] - fn poll(&mut self, target: &mut Self::Target) -> io::Result<()> { - self.poll(target) - } - #[inline] fn can_recreate(&mut self, reason: DisconnectReason) -> bool { self.can_recreate(reason) @@ -163,8 +153,6 @@ pub mod ws { ctx: &mut C, ) -> io::Result>>>; - fn poll(&mut self, ws: &mut Websocket>, ctx: &mut C) -> io::Result<()>; - fn can_recreate(&mut self, _reason: DisconnectReason, _ctx: &mut C) -> bool { true } @@ -185,11 +173,6 @@ pub mod ws { self.create_websocket(addr, context) } - #[inline] - fn poll(&mut self, target: &mut Self::Target, context: &mut C) -> io::Result<()> { - self.poll(target, context) - } - #[inline] fn can_recreate(&mut self, reason: DisconnectReason, context: &mut C) -> bool { self.can_recreate(reason, context) diff --git a/src/service/mod.rs b/src/service/mod.rs index 5deb6e9..aef27e5 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -54,10 +54,7 @@ pub trait IntoIOService { /// Defines how an instance that implements [`Selector`] can be transformed /// into an [`IOService`] with [`Context`], facilitating the management of asynchronous I/O operations. pub trait IntoIOServiceWithContext { - fn into_io_service_with_context( - self, - context: &mut C, - ) -> IOService + fn into_io_service_with_context(self) -> IOService where Self: Selector, Self: Sized; @@ -219,7 +216,10 @@ where /// on the ['Selector'] poll results. It then iterates through all endpoints, either /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). - pub fn poll(&mut self) -> io::Result<()> { + pub fn poll(&mut self, mut action: F) -> io::Result<()> + where + F: FnMut(&mut E::Target, &mut E) -> io::Result<()>, + { // check for pending endpoints (one at a time & throttled) if !self.pending_endpoints.is_empty() { let current_time_ns = self.time_source.current_time_nanos(); @@ -285,8 +285,8 @@ where // poll endpoints self.io_nodes.retain(|_token, io_node| { - let (stream, (_, endpoint)) = io_node.as_parts_mut(); - if let Err(err) = endpoint.poll(stream) { + let (target, (_, endpoint)) = io_node.as_parts_mut(); + if let Err(err) = action(target, endpoint) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); if endpoint.can_recreate(DisconnectReason::other(err)) { @@ -335,7 +335,10 @@ where /// on the `SelectService` poll results. It then iterates through all endpoints, either /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). - pub fn poll(&mut self, context: &mut C) -> io::Result<()> { + pub fn poll(&mut self, context: &mut C, mut action: F) -> io::Result<()> + where + F: FnMut(&mut E::Target, &mut C, &mut E) -> io::Result<()>, + { // check for pending endpoints (one at a time & throttled) if !self.pending_endpoints.is_empty() { let current_time_ns = self.time_source.current_time_nanos(); @@ -401,8 +404,8 @@ where // poll endpoints self.io_nodes.retain(|_token, io_node| { - let (stream, (_, endpoint)) = io_node.as_parts_mut(); - if let Err(err) = endpoint.poll(stream, context) { + let (target, (_, endpoint)) = io_node.as_parts_mut(); + if let Err(err) = action(target, context, endpoint) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); if endpoint.can_recreate(DisconnectReason::other(err), context) { diff --git a/src/service/select/direct.rs b/src/service/select/direct.rs index f05d553..e16dcd8 100644 --- a/src/service/select/direct.rs +++ b/src/service/select/direct.rs @@ -60,10 +60,7 @@ impl IntoIOService for DirectSelector { } impl> IntoIOServiceWithContext for DirectSelector { - fn into_io_service_with_context( - self, - _ctx: &mut C, - ) -> IOService + fn into_io_service_with_context(self) -> IOService where Self: Selector, Self: Sized, diff --git a/src/service/select/mio.rs b/src/service/select/mio.rs index 6f3fdb2..31ea53c 100644 --- a/src/service/select/mio.rs +++ b/src/service/select/mio.rs @@ -86,10 +86,7 @@ impl IntoIOService for MioSelector { } impl> IntoIOServiceWithContext for MioSelector { - fn into_io_service_with_context( - self, - _ctx: &mut C, - ) -> IOService + fn into_io_service_with_context(self) -> IOService where Self: Selector, Self: Sized, From e2a32702ff2f65d7f4fb81c3f1d44a805de12576 Mon Sep 17 00:00:00 2001 From: Tom Brzozowski Date: Thu, 16 Oct 2025 12:07:05 +0100 Subject: [PATCH 2/6] configure dns timeout --- src/service/mod.rs | 24 +++++++++++++++++++----- src/ws/handshake.rs | 3 ++- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/service/mod.rs b/src/service/mod.rs index aef27e5..da76b8c 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -22,8 +22,6 @@ pub mod time; const ENDPOINT_CREATION_THROTTLE_NS: u64 = Duration::from_secs(1).as_nanos() as u64; -const DNS_RESOLVE_TIMEOUT_NS: u64 = Duration::from_secs(5).as_nanos() as u64; - /// Endpoint handle. #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] #[repr(transparent)] @@ -40,6 +38,7 @@ pub struct IOService { auto_disconnect: Option Duration>>, time_source: TS, dns_resolver: D, + dns_query_timeout_ns: Option, } /// Defines how an instance that implements `SelectService` can be transformed @@ -72,6 +71,7 @@ impl IOService { auto_disconnect: None, time_source, dns_resolver, + dns_query_timeout_ns: None, } } @@ -91,6 +91,15 @@ impl IOService { } } + /// Specify DNS query timeout. This is only relevant when using asynchronous form of + /// [`DnsResolver`]. + pub fn with_dns_query_timeout(self, timeout: Duration) -> IOService { + Self { + dns_query_timeout_ns: Some(timeout.as_nanos() as u64), + ..self + } + } + /// Specify custom [`TimeSource`] instead of the default system time source. pub fn with_time_source(self, time_source: T) -> IOService { IOService { @@ -102,6 +111,7 @@ impl IOService { next_endpoint_create_time_ns: self.next_endpoint_create_time_ns, selector: self.selector, dns_resolver: self.dns_resolver, + dns_query_timeout_ns: self.dns_query_timeout_ns, } } @@ -116,6 +126,7 @@ impl IOService { next_endpoint_create_time_ns: self.next_endpoint_create_time_ns, selector: self.selector, dns_resolver, + dns_query_timeout_ns: self.dns_query_timeout_ns, } } @@ -187,9 +198,12 @@ impl IOService { where TS: TimeSource, { - let now = self.time_source.current_time_nanos(); - if now > created_time_ns + DNS_RESOLVE_TIMEOUT_NS { - return Err(io::Error::new(ErrorKind::TimedOut, "dns resolution timed out")); + // check if dns query resolution timed out + if let Some(dns_query_timeout) = self.dns_query_timeout_ns { + let now = self.time_source.current_time_nanos(); + if now > created_time_ns + dns_query_timeout { + return Err(io::Error::new(ErrorKind::TimedOut, "dns resolution timed out")); + } } match query.poll() { Ok(addrs) => { diff --git a/src/ws/handshake.rs b/src/ws/handshake.rs index 7f4df93..26d17bf 100644 --- a/src/ws/handshake.rs +++ b/src/ws/handshake.rs @@ -80,7 +80,8 @@ impl Handshaker { let mut response = Response::new(&mut headers); response.parse(self.inbound_buffer.view()).map_err(io::Error::other)?; if response.code.unwrap() != StatusCode::SWITCHING_PROTOCOLS.as_u16() { - return Err(io::Error::other("unable to switch protocols")); + let reason = response.reason.unwrap_or_default(); + return Err(io::Error::other(format!("unable to switch protocols, reason: {}", reason))); } self.state = Completed; } From 9a9211e5f895e4a0719cbbb4f6a3f30d09996145 Mon Sep 17 00:00:00 2001 From: Tom Brzozowski Date: Thu, 16 Oct 2025 13:01:25 +0100 Subject: [PATCH 3/6] refactor check for pending endpoints --- src/service/mod.rs | 100 +++++++++++++++++++-------------------------- 1 file changed, 43 insertions(+), 57 deletions(-) diff --git a/src/service/mod.rs b/src/service/mod.rs index da76b8c..0747be4 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -217,6 +217,42 @@ impl IOService { Err(err) => Err(err), } } + + #[cold] + fn check_pending_endpoints(&mut self, create_target: F) -> io::Result<()> + where + E: ConnectionInfoProvider, + TS: TimeSource, + F: FnOnce(&mut E, SocketAddr) -> io::Result::Target>>, + { + let current_time_ns = self.time_source.current_time_nanos(); + if current_time_ns > self.next_endpoint_create_time_ns { + if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() { + if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? { + match create_target(&mut endpoint, addr)? { + Some(stream) => { + let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect()); + let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr); + self.selector.register(handle.0, &mut io_node)?; + self.io_nodes.insert(handle.0, io_node); + } + None => { + // request new dns query + let info = endpoint.connection_info(); + let query = self.dns_resolver.new_query(info.host(), info.port())?; + let now = self.time_source.current_time_nanos(); + self.pending_endpoints.push_back((handle, query, now, endpoint)) + } + } + } else { + self.pending_endpoints + .push_back((handle, query, query_time_ns, endpoint)) + } + } + self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; + } + Ok(()) + } } impl IOService @@ -236,32 +272,7 @@ where { // check for pending endpoints (one at a time & throttled) if !self.pending_endpoints.is_empty() { - let current_time_ns = self.time_source.current_time_nanos(); - if current_time_ns > self.next_endpoint_create_time_ns { - if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() { - if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? { - match endpoint.create_target(addr)? { - Some(stream) => { - let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect()); - let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr); - self.selector.register(handle.0, &mut io_node)?; - self.io_nodes.insert(handle.0, io_node); - } - None => { - // request new dns query - let info = endpoint.connection_info(); - let query = self.dns_resolver.new_query(info.host(), info.port())?; - let now = self.time_source.current_time_nanos(); - self.pending_endpoints.push_back((handle, query, now, endpoint)) - } - } - } else { - self.pending_endpoints - .push_back((handle, query, query_time_ns, endpoint)) - } - } - self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; - } + self.check_pending_endpoints(|endpoint, addr| endpoint.create_target(addr))?; } // check for readiness events @@ -349,38 +360,13 @@ where /// on the `SelectService` poll results. It then iterates through all endpoints, either /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). - pub fn poll(&mut self, context: &mut C, mut action: F) -> io::Result<()> + pub fn poll(&mut self, ctx: &mut C, mut action: F) -> io::Result<()> where F: FnMut(&mut E::Target, &mut C, &mut E) -> io::Result<()>, { // check for pending endpoints (one at a time & throttled) if !self.pending_endpoints.is_empty() { - let current_time_ns = self.time_source.current_time_nanos(); - if current_time_ns > self.next_endpoint_create_time_ns { - if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() { - if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? { - match endpoint.create_target(addr, context)? { - Some(stream) => { - let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect()); - let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr); - self.selector.register(handle.0, &mut io_node)?; - self.io_nodes.insert(handle.0, io_node); - } - None => { - // request new dns query - let info = endpoint.connection_info(); - let query = self.dns_resolver.new_query(info.host(), info.port())?; - let now = self.time_source.current_time_nanos(); - self.pending_endpoints.push_back((handle, query, now, endpoint)) - } - } - } else { - self.pending_endpoints - .push_back((handle, query, query_time_ns, endpoint)) - } - } - self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; - } + self.check_pending_endpoints(|endpoint, addr| endpoint.create_target(addr, ctx))?; } // check for readiness events @@ -393,10 +379,10 @@ where let force_disconnect = current_time_ns > io_node.disconnect_time_ns; if force_disconnect { // check if we really have to disconnect - return if io_node.as_endpoint_mut().1.can_auto_disconnect(context) { + return if io_node.as_endpoint_mut().1.can_auto_disconnect(ctx) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); - if endpoint.can_recreate(DisconnectReason::auto_disconnect(io_node.ttl), context) { + if endpoint.can_recreate(DisconnectReason::auto_disconnect(io_node.ttl), ctx) { let info = endpoint.connection_info(); let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap(); let now = self.time_source.current_time_nanos(); @@ -419,10 +405,10 @@ where // poll endpoints self.io_nodes.retain(|_token, io_node| { let (target, (_, endpoint)) = io_node.as_parts_mut(); - if let Err(err) = action(target, context, endpoint) { + if let Err(err) = action(target, ctx, endpoint) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); - if endpoint.can_recreate(DisconnectReason::other(err), context) { + if endpoint.can_recreate(DisconnectReason::other(err), ctx) { let info = endpoint.connection_info(); let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap(); let now = self.time_source.current_time_nanos(); From 650f9749939f05cc00ec5f716fd948fe23262105 Mon Sep 17 00:00:00 2001 From: Tom Brzozowski Date: Thu, 16 Oct 2025 13:06:31 +0100 Subject: [PATCH 4/6] refactor check for pending endpoints --- src/service/mod.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/service/mod.rs b/src/service/mod.rs index 0747be4..12421c2 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -264,7 +264,8 @@ where { /// This method polls all registered endpoints for readiness and performs I/O operations based /// on the ['Selector'] poll results. It then iterates through all endpoints, either - /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] + /// updating existing streams or creating and registering new ones. If there's pending IO on the stream, + /// the provided `action` closure will be invoked. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). pub fn poll(&mut self, mut action: F) -> io::Result<()> where @@ -358,7 +359,8 @@ where { /// This method polls all registered endpoints for readiness passing the [`Context`] and performs I/O operations based /// on the `SelectService` poll results. It then iterates through all endpoints, either - /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] + /// updating existing streams or creating and registering new ones. If there's pending IO on the stream, + /// the provided `action` closure will be invoked. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). pub fn poll(&mut self, ctx: &mut C, mut action: F) -> io::Result<()> where From 26f597d3c0c94a7c0ddb9f04bbd5599e96b5a376 Mon Sep 17 00:00:00 2001 From: Tom Brzozowski Date: Thu, 16 Oct 2025 13:06:59 +0100 Subject: [PATCH 5/6] version bump --- Cargo.toml | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 85e9ded..26fff16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "boomnet" -version = "0.0.70" +version = "0.0.71" edition = "2024" license = "MIT" description = "Framework for building low latency clients on top of TCP." diff --git a/README.md b/README.md index 2ac2bfc..4c44507 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ particularly focusing on TCP stream-oriented clients that utilise various protoc Simply declare dependency on `boomnet` in your `Cargo.toml` and select desired [features](#features). ```toml [dependencies] -boomnet = { version = "0.0.70", features = ["rustls-webpki", "ws", "ext"]} +boomnet = { version = "0.0.71", features = ["rustls-webpki", "ws", "ext"]} ``` ## Design Principles From 24220565ba6c2b6a01decd83bfaba7c6d9daf1fe Mon Sep 17 00:00:00 2001 From: Tom Brzozowski Date: Thu, 16 Oct 2025 13:24:17 +0100 Subject: [PATCH 6/6] fix --- Cargo.toml | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 26fff16..85e9ded 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "boomnet" -version = "0.0.71" +version = "0.0.70" edition = "2024" license = "MIT" description = "Framework for building low latency clients on top of TCP." diff --git a/README.md b/README.md index 4c44507..2ac2bfc 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ particularly focusing on TCP stream-oriented clients that utilise various protoc Simply declare dependency on `boomnet` in your `Cargo.toml` and select desired [features](#features). ```toml [dependencies] -boomnet = { version = "0.0.71", features = ["rustls-webpki", "ws", "ext"]} +boomnet = { version = "0.0.70", features = ["rustls-webpki", "ws", "ext"]} ``` ## Design Principles