diff --git a/benches/latency/endpoint.rs b/benches/latency/endpoint.rs index 32035f3..4faecb8 100644 --- a/benches/latency/endpoint.rs +++ b/benches/latency/endpoint.rs @@ -45,8 +45,21 @@ impl EndpointWithContext for TestEndpoint { .into_websocket("/"); Ok(Some(ws)) } +} + +impl TestEndpoint { + pub fn new(port: u16, payload: &'static str) -> Self { + Self { + connection_info: ConnectionInfo::new("127.0.0.1", port), + payload, + } + } - fn poll(&mut self, ws: &mut Self::Target, ctx: &mut TestContext) -> std::io::Result<()> { + pub fn poll( + &mut self, + ws: &mut >::Target, + ctx: &mut TestContext, + ) -> std::io::Result<()> { if ctx.wants_write { ws.send_text(true, Some(self.payload.as_bytes()))?; ctx.wants_write = false; @@ -59,12 +72,3 @@ impl EndpointWithContext for TestEndpoint { Ok(()) } } - -impl TestEndpoint { - pub fn new(port: u16, payload: &'static str) -> Self { - Self { - connection_info: ConnectionInfo::new("127.0.0.1", port), - payload, - } - } -} diff --git a/benches/latency/main.rs b/benches/latency/main.rs index 14218e2..18f3b86 100644 --- a/benches/latency/main.rs +++ b/benches/latency/main.rs @@ -62,13 +62,15 @@ fn boomnet_rtt_benchmark_io_service(c: &mut Criterion) { // setup io service let mut ctx = TestContext::new(); - let mut io_service = DirectSelector::new().unwrap().into_io_service_with_context(&mut ctx); + let mut io_service = DirectSelector::new().unwrap().into_io_service_with_context(); io_service.register(TestEndpoint::new(9003, MSG)).unwrap(); group.bench_function("boomnet_rtt_io_service", |b| { b.iter(|| { loop { - io_service.poll(&mut ctx).unwrap(); + io_service + .poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll(ws, ctx)) + .unwrap(); if ctx.processed == 100 { ctx.wants_write = true; ctx.processed = 0; diff --git a/examples/common/mod.rs b/examples/common/mod.rs index bf13a01..8f1690e 100644 --- a/examples/common/mod.rs +++ b/examples/common/mod.rs @@ -61,6 +61,44 @@ impl TradeEndpoint { )?; Ok(()) } + + #[inline] + #[allow(dead_code)] + pub fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, data) = frame? { + match self.id % 4 { + 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), + 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), + 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), + 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), + _ => {} + } + } + } + Ok(()) + } + + #[inline] + #[allow(dead_code)] + pub fn poll_ctx( + &mut self, + ws: &mut TlsWebsocket<::Stream>, + _ctx: &mut FeedContext, + ) -> io::Result<()> { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, data) = frame? { + match self.id % 4 { + 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), + 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), + 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), + 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), + _ => {} + } + } + } + Ok(()) + } } impl ConnectionInfoProvider for TradeEndpoint { @@ -85,22 +123,6 @@ impl TlsWebsocketEndpoint for TradeEndpoint { Ok(Some(ws)) } - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - for frame in ws.read_batch()? { - if let WebsocketFrame::Text(fin, data) = frame? { - match self.id % 4 { - 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), - 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), - 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), - 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), - _ => {} - } - } - } - Ok(()) - } - fn can_recreate(&mut self, reason: DisconnectReason) -> bool { warn!("connection disconnected: {reason}"); true @@ -125,20 +147,4 @@ impl TlsWebsocketEndpointWithContext for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket, _ctx: &mut FeedContext) -> io::Result<()> { - for frame in ws.read_batch()? { - if let WebsocketFrame::Text(fin, data) = frame? { - match self.id % 4 { - 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), - 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), - 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), - 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), - _ => {} - } - } - } - Ok(()) - } } diff --git a/examples/endpoint_with_timer.rs b/examples/endpoint_with_timer.rs index 66d5410..bba3618 100644 --- a/examples/endpoint_with_timer.rs +++ b/examples/endpoint_with_timer.rs @@ -30,6 +30,23 @@ impl TradeEndpoint { next_disconnect_time_ns: ctx.current_time_ns() + Duration::from_secs(10).as_nanos() as u64, } } + + #[inline] + fn poll( + &mut self, + ws: &mut TlsWebsocket<>::Stream>, + ctx: &mut FeedContext, + ) -> io::Result<()> { + while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { + info!("({fin}) {}", String::from_utf8_lossy(data)); + } + let now_ns = ctx.current_time_ns(); + if now_ns > self.next_disconnect_time_ns { + self.next_disconnect_time_ns = now_ns + Duration::from_secs(10).as_nanos() as u64; + return Err(io::Error::other("disconnected due to timer")); + } + Ok(()) + } } #[derive(Debug)] @@ -75,19 +92,6 @@ impl TlsWebsocketEndpointWithContext for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket, ctx: &mut FeedContext) -> io::Result<()> { - while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { - info!("({fin}) {}", String::from_utf8_lossy(data)); - } - let now_ns = ctx.current_time_ns(); - if now_ns > self.next_disconnect_time_ns { - self.next_disconnect_time_ns = now_ns + Duration::from_secs(10).as_nanos() as u64; - return Err(io::Error::other("disconnected due to timer")); - } - Ok(()) - } } fn main() -> anyhow::Result<()> { @@ -95,13 +99,12 @@ fn main() -> anyhow::Result<()> { let mut ctx = FeedContext::new(); - let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut ctx); + let mut io_service = MioSelector::new()?.into_io_service_with_context(); let endpoint_btc = TradeEndpoint::new("wss://stream1.binance.com:443/ws", "btcusdt", &ctx); io_service.register(endpoint_btc)?; - loop { - io_service.poll(&mut ctx)?; + io_service.poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll(ws, ctx))?; } } diff --git a/examples/io_service_dispatch.rs b/examples/io_service_dispatch.rs index 634801b..333e6ed 100644 --- a/examples/io_service_dispatch.rs +++ b/examples/io_service_dispatch.rs @@ -23,11 +23,11 @@ fn main() -> anyhow::Result<()> { if success { break; } else { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/io_service_with_async_dns.rs b/examples/io_service_with_async_dns.rs index 57a5a18..d773d6c 100644 --- a/examples/io_service_with_async_dns.rs +++ b/examples/io_service_with_async_dns.rs @@ -9,10 +9,10 @@ mod common; fn main() -> anyhow::Result<()> { env_logger::init(); - let mut context = FeedContext::new(); + let mut ctx = FeedContext::new(); let mut io_service = MioSelector::new()? - .into_io_service_with_context(&mut context) + .into_io_service_with_context() .with_dns_resolver(AsyncDnsResolver::new()?); let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); @@ -24,6 +24,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll(&mut context)?; + io_service.poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll_ctx(ws, ctx))?; } } diff --git a/examples/io_service_with_auto_disconnect.rs b/examples/io_service_with_auto_disconnect.rs index 309f444..9d94e4f 100644 --- a/examples/io_service_with_auto_disconnect.rs +++ b/examples/io_service_with_auto_disconnect.rs @@ -24,6 +24,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_btc_2)?; loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/io_service_with_context.rs b/examples/io_service_with_context.rs index 1403020..f68fbc9 100644 --- a/examples/io_service_with_context.rs +++ b/examples/io_service_with_context.rs @@ -8,9 +8,9 @@ mod common; fn main() -> anyhow::Result<()> { env_logger::init(); - let mut context = FeedContext::new(); + let mut ctx = FeedContext::new(); - let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut context); + let mut io_service = MioSelector::new()?.into_io_service_with_context(); let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); let endpoint_eth = TradeEndpoint::new(1, "wss://stream2.binance.com:443/ws", None, "ethusdt"); @@ -21,6 +21,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll(&mut context)?; + io_service.poll(&mut ctx, |ws, ctx, endpoint| endpoint.poll_ctx(ws, ctx))?; } } diff --git a/examples/io_service_with_direct_selector.rs b/examples/io_service_with_direct_selector.rs index 10b436c..f1d608e 100644 --- a/examples/io_service_with_direct_selector.rs +++ b/examples/io_service_with_direct_selector.rs @@ -33,6 +33,16 @@ impl TradeEndpoint { ws_endpoint, } } + + #[inline] + fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + for frame in ws.read_batch()? { + if let WebsocketFrame::Text(fin, data) = frame? { + println!("[{}] ({fin}) {}", self.id, String::from_utf8_lossy(data)) + } + } + Ok(()) + } } impl ConnectionInfoProvider for TradeEndpoint { @@ -57,16 +67,6 @@ impl TlsWebsocketEndpoint for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - for frame in ws.read_batch()? { - if let WebsocketFrame::Text(fin, data) = frame? { - println!("[{}] ({fin}) {}", self.id, String::from_utf8_lossy(data)) - } - } - Ok(()) - } } fn main() -> anyhow::Result<()> { @@ -83,6 +83,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/io_service_without_context.rs b/examples/io_service_without_context.rs index 5623620..319e206 100644 --- a/examples/io_service_without_context.rs +++ b/examples/io_service_without_context.rs @@ -19,6 +19,6 @@ fn main() -> anyhow::Result<()> { io_service.register(endpoint_xrp)?; loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/examples/polymorphic_endpoints.rs b/examples/polymorphic_endpoints.rs index e884d4a..f4a2b75 100644 --- a/examples/polymorphic_endpoints.rs +++ b/examples/polymorphic_endpoints.rs @@ -22,6 +22,15 @@ enum MarketDataEndpoint { Ticker(TickerEndpoint), } +impl MarketDataEndpoint { + fn poll(&mut self, ws: &mut Websocket::Stream>>) -> io::Result<()> { + match self { + MarketDataEndpoint::Ticker(ticker) => ticker.poll(ws), + MarketDataEndpoint::Trade(trade) => trade.poll(ws), + } + } +} + impl ConnectionInfoProvider for MarketDataEndpoint { fn connection_info(&self) -> &ConnectionInfo { match self { @@ -40,13 +49,6 @@ impl TlsWebsocketEndpoint for MarketDataEndpoint { MarketDataEndpoint::Trade(trade) => trade.create_websocket(addr), } } - - fn poll(&mut self, ws: &mut Websocket>) -> io::Result<()> { - match self { - MarketDataEndpoint::Ticker(ticker) => TlsWebsocketEndpoint::poll(ticker, ws), - MarketDataEndpoint::Trade(trade) => TlsWebsocketEndpoint::poll(trade, ws), - } - } } struct TradeEndpoint { @@ -64,6 +66,14 @@ impl TradeEndpoint { instrument, } } + + #[inline] + fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { + info!("({fin}) {}", String::from_utf8_lossy(data)); + } + Ok(()) + } } impl ConnectionInfoProvider for TradeEndpoint { @@ -90,14 +100,6 @@ impl TlsWebsocketEndpoint for TradeEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { - info!("({fin}) {}", String::from_utf8_lossy(data)); - } - Ok(()) - } } struct TickerEndpoint { @@ -115,6 +117,14 @@ impl TickerEndpoint { instrument, } } + + #[inline] + fn poll(&mut self, ws: &mut TlsWebsocket<::Stream>) -> io::Result<()> { + while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { + info!("({fin}) {}", String::from_utf8_lossy(data)); + } + Ok(()) + } } impl ConnectionInfoProvider for TickerEndpoint { @@ -141,14 +151,6 @@ impl TlsWebsocketEndpoint for TickerEndpoint { Ok(Some(ws)) } - - #[inline] - fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { - while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { - info!("({fin}) {}", String::from_utf8_lossy(data)); - } - Ok(()) - } } fn main() -> anyhow::Result<()> { @@ -163,6 +165,6 @@ fn main() -> anyhow::Result<()> { io_service.register(trade); loop { - io_service.poll()?; + io_service.poll(|ws, endpoint| endpoint.poll(ws))?; } } diff --git a/src/service/endpoint.rs b/src/service/endpoint.rs index 0da9de7..2b0c27c 100644 --- a/src/service/endpoint.rs +++ b/src/service/endpoint.rs @@ -16,8 +16,8 @@ pub trait Endpoint: ConnectionInfoProvider { /// await the next connection attempt with (possibly) different `addr`. fn create_target(&mut self, addr: SocketAddr) -> io::Result>; - /// Called by the `IOService` on each duty cycle. - fn poll(&mut self, target: &mut Self::Target) -> io::Result<()>; + // /// Called by the `IOService` on each duty cycle. + // fn poll(&mut self, target: &mut Self::Target) -> io::Result<()>; /// Upon disconnection `IOService` will query the endpoint if the connection can be /// recreated, passing the disconnect `reason`. If not, it will cause program to panic. @@ -48,9 +48,6 @@ pub trait EndpointWithContext: ConnectionInfoProvider { /// return `Ok(None)` and await the next connection attempt with (possibly) different `addr`. fn create_target(&mut self, addr: SocketAddr, context: &mut C) -> io::Result>; - /// Called by the `IOService` on each duty cycle passing user provided `Context`. - fn poll(&mut self, target: &mut Self::Target, context: &mut C) -> io::Result<()>; - /// Upon disconnection `IOService` will query the endpoint if the connection can be /// recreated, passing the disconnect `reason`. If not, it will cause program to panic. fn can_recreate(&mut self, _reason: DisconnectReason, _context: &mut C) -> bool { @@ -116,8 +113,6 @@ pub mod ws { fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>>>; - fn poll(&mut self, ws: &mut Websocket>) -> io::Result<()>; - fn can_recreate(&mut self, _reason: DisconnectReason) -> bool { true } @@ -138,11 +133,6 @@ pub mod ws { self.create_websocket(addr) } - #[inline] - fn poll(&mut self, target: &mut Self::Target) -> io::Result<()> { - self.poll(target) - } - #[inline] fn can_recreate(&mut self, reason: DisconnectReason) -> bool { self.can_recreate(reason) @@ -163,8 +153,6 @@ pub mod ws { ctx: &mut C, ) -> io::Result>>>; - fn poll(&mut self, ws: &mut Websocket>, ctx: &mut C) -> io::Result<()>; - fn can_recreate(&mut self, _reason: DisconnectReason, _ctx: &mut C) -> bool { true } @@ -185,11 +173,6 @@ pub mod ws { self.create_websocket(addr, context) } - #[inline] - fn poll(&mut self, target: &mut Self::Target, context: &mut C) -> io::Result<()> { - self.poll(target, context) - } - #[inline] fn can_recreate(&mut self, reason: DisconnectReason, context: &mut C) -> bool { self.can_recreate(reason, context) diff --git a/src/service/mod.rs b/src/service/mod.rs index 5deb6e9..12421c2 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -22,8 +22,6 @@ pub mod time; const ENDPOINT_CREATION_THROTTLE_NS: u64 = Duration::from_secs(1).as_nanos() as u64; -const DNS_RESOLVE_TIMEOUT_NS: u64 = Duration::from_secs(5).as_nanos() as u64; - /// Endpoint handle. #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] #[repr(transparent)] @@ -40,6 +38,7 @@ pub struct IOService { auto_disconnect: Option Duration>>, time_source: TS, dns_resolver: D, + dns_query_timeout_ns: Option, } /// Defines how an instance that implements `SelectService` can be transformed @@ -54,10 +53,7 @@ pub trait IntoIOService { /// Defines how an instance that implements [`Selector`] can be transformed /// into an [`IOService`] with [`Context`], facilitating the management of asynchronous I/O operations. pub trait IntoIOServiceWithContext { - fn into_io_service_with_context( - self, - context: &mut C, - ) -> IOService + fn into_io_service_with_context(self) -> IOService where Self: Selector, Self: Sized; @@ -75,6 +71,7 @@ impl IOService { auto_disconnect: None, time_source, dns_resolver, + dns_query_timeout_ns: None, } } @@ -94,6 +91,15 @@ impl IOService { } } + /// Specify DNS query timeout. This is only relevant when using asynchronous form of + /// [`DnsResolver`]. + pub fn with_dns_query_timeout(self, timeout: Duration) -> IOService { + Self { + dns_query_timeout_ns: Some(timeout.as_nanos() as u64), + ..self + } + } + /// Specify custom [`TimeSource`] instead of the default system time source. pub fn with_time_source(self, time_source: T) -> IOService { IOService { @@ -105,6 +111,7 @@ impl IOService { next_endpoint_create_time_ns: self.next_endpoint_create_time_ns, selector: self.selector, dns_resolver: self.dns_resolver, + dns_query_timeout_ns: self.dns_query_timeout_ns, } } @@ -119,6 +126,7 @@ impl IOService { next_endpoint_create_time_ns: self.next_endpoint_create_time_ns, selector: self.selector, dns_resolver, + dns_query_timeout_ns: self.dns_query_timeout_ns, } } @@ -190,9 +198,12 @@ impl IOService { where TS: TimeSource, { - let now = self.time_source.current_time_nanos(); - if now > created_time_ns + DNS_RESOLVE_TIMEOUT_NS { - return Err(io::Error::new(ErrorKind::TimedOut, "dns resolution timed out")); + // check if dns query resolution timed out + if let Some(dns_query_timeout) = self.dns_query_timeout_ns { + let now = self.time_source.current_time_nanos(); + if now > created_time_ns + dns_query_timeout { + return Err(io::Error::new(ErrorKind::TimedOut, "dns resolution timed out")); + } } match query.poll() { Ok(addrs) => { @@ -206,6 +217,42 @@ impl IOService { Err(err) => Err(err), } } + + #[cold] + fn check_pending_endpoints(&mut self, create_target: F) -> io::Result<()> + where + E: ConnectionInfoProvider, + TS: TimeSource, + F: FnOnce(&mut E, SocketAddr) -> io::Result::Target>>, + { + let current_time_ns = self.time_source.current_time_nanos(); + if current_time_ns > self.next_endpoint_create_time_ns { + if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() { + if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? { + match create_target(&mut endpoint, addr)? { + Some(stream) => { + let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect()); + let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr); + self.selector.register(handle.0, &mut io_node)?; + self.io_nodes.insert(handle.0, io_node); + } + None => { + // request new dns query + let info = endpoint.connection_info(); + let query = self.dns_resolver.new_query(info.host(), info.port())?; + let now = self.time_source.current_time_nanos(); + self.pending_endpoints.push_back((handle, query, now, endpoint)) + } + } + } else { + self.pending_endpoints + .push_back((handle, query, query_time_ns, endpoint)) + } + } + self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; + } + Ok(()) + } } impl IOService @@ -217,37 +264,16 @@ where { /// This method polls all registered endpoints for readiness and performs I/O operations based /// on the ['Selector'] poll results. It then iterates through all endpoints, either - /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] + /// updating existing streams or creating and registering new ones. If there's pending IO on the stream, + /// the provided `action` closure will be invoked. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). - pub fn poll(&mut self) -> io::Result<()> { + pub fn poll(&mut self, mut action: F) -> io::Result<()> + where + F: FnMut(&mut E::Target, &mut E) -> io::Result<()>, + { // check for pending endpoints (one at a time & throttled) if !self.pending_endpoints.is_empty() { - let current_time_ns = self.time_source.current_time_nanos(); - if current_time_ns > self.next_endpoint_create_time_ns { - if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() { - if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? { - match endpoint.create_target(addr)? { - Some(stream) => { - let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect()); - let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr); - self.selector.register(handle.0, &mut io_node)?; - self.io_nodes.insert(handle.0, io_node); - } - None => { - // request new dns query - let info = endpoint.connection_info(); - let query = self.dns_resolver.new_query(info.host(), info.port())?; - let now = self.time_source.current_time_nanos(); - self.pending_endpoints.push_back((handle, query, now, endpoint)) - } - } - } else { - self.pending_endpoints - .push_back((handle, query, query_time_ns, endpoint)) - } - } - self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; - } + self.check_pending_endpoints(|endpoint, addr| endpoint.create_target(addr))?; } // check for readiness events @@ -285,8 +311,8 @@ where // poll endpoints self.io_nodes.retain(|_token, io_node| { - let (stream, (_, endpoint)) = io_node.as_parts_mut(); - if let Err(err) = endpoint.poll(stream) { + let (target, (_, endpoint)) = io_node.as_parts_mut(); + if let Err(err) = action(target, endpoint) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); if endpoint.can_recreate(DisconnectReason::other(err)) { @@ -333,37 +359,16 @@ where { /// This method polls all registered endpoints for readiness passing the [`Context`] and performs I/O operations based /// on the `SelectService` poll results. It then iterates through all endpoints, either - /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] + /// updating existing streams or creating and registering new ones. If there's pending IO on the stream, + /// the provided `action` closure will be invoked. It uses [`Endpoint::can_recreate`] /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). - pub fn poll(&mut self, context: &mut C) -> io::Result<()> { + pub fn poll(&mut self, ctx: &mut C, mut action: F) -> io::Result<()> + where + F: FnMut(&mut E::Target, &mut C, &mut E) -> io::Result<()>, + { // check for pending endpoints (one at a time & throttled) if !self.pending_endpoints.is_empty() { - let current_time_ns = self.time_source.current_time_nanos(); - if current_time_ns > self.next_endpoint_create_time_ns { - if let Some((handle, mut query, query_time_ns, mut endpoint)) = self.pending_endpoints.pop_front() { - if let Some(addr) = self.resolve_dns(&mut query, query_time_ns)? { - match endpoint.create_target(addr, context)? { - Some(stream) => { - let ttl = self.auto_disconnect.as_ref().map(|auto_disconnect| auto_disconnect()); - let mut io_node = IONode::new(stream, handle, endpoint, ttl, &self.time_source, addr); - self.selector.register(handle.0, &mut io_node)?; - self.io_nodes.insert(handle.0, io_node); - } - None => { - // request new dns query - let info = endpoint.connection_info(); - let query = self.dns_resolver.new_query(info.host(), info.port())?; - let now = self.time_source.current_time_nanos(); - self.pending_endpoints.push_back((handle, query, now, endpoint)) - } - } - } else { - self.pending_endpoints - .push_back((handle, query, query_time_ns, endpoint)) - } - } - self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; - } + self.check_pending_endpoints(|endpoint, addr| endpoint.create_target(addr, ctx))?; } // check for readiness events @@ -376,10 +381,10 @@ where let force_disconnect = current_time_ns > io_node.disconnect_time_ns; if force_disconnect { // check if we really have to disconnect - return if io_node.as_endpoint_mut().1.can_auto_disconnect(context) { + return if io_node.as_endpoint_mut().1.can_auto_disconnect(ctx) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); - if endpoint.can_recreate(DisconnectReason::auto_disconnect(io_node.ttl), context) { + if endpoint.can_recreate(DisconnectReason::auto_disconnect(io_node.ttl), ctx) { let info = endpoint.connection_info(); let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap(); let now = self.time_source.current_time_nanos(); @@ -401,11 +406,11 @@ where // poll endpoints self.io_nodes.retain(|_token, io_node| { - let (stream, (_, endpoint)) = io_node.as_parts_mut(); - if let Err(err) = endpoint.poll(stream, context) { + let (target, (_, endpoint)) = io_node.as_parts_mut(); + if let Err(err) = action(target, ctx, endpoint) { self.selector.unregister(io_node).unwrap(); let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); - if endpoint.can_recreate(DisconnectReason::other(err), context) { + if endpoint.can_recreate(DisconnectReason::other(err), ctx) { let info = endpoint.connection_info(); let query = self.dns_resolver.new_query(info.host(), info.port()).unwrap(); let now = self.time_source.current_time_nanos(); diff --git a/src/service/select/direct.rs b/src/service/select/direct.rs index f05d553..e16dcd8 100644 --- a/src/service/select/direct.rs +++ b/src/service/select/direct.rs @@ -60,10 +60,7 @@ impl IntoIOService for DirectSelector { } impl> IntoIOServiceWithContext for DirectSelector { - fn into_io_service_with_context( - self, - _ctx: &mut C, - ) -> IOService + fn into_io_service_with_context(self) -> IOService where Self: Selector, Self: Sized, diff --git a/src/service/select/mio.rs b/src/service/select/mio.rs index 6f3fdb2..31ea53c 100644 --- a/src/service/select/mio.rs +++ b/src/service/select/mio.rs @@ -86,10 +86,7 @@ impl IntoIOService for MioSelector { } impl> IntoIOServiceWithContext for MioSelector { - fn into_io_service_with_context( - self, - _ctx: &mut C, - ) -> IOService + fn into_io_service_with_context(self) -> IOService where Self: Selector, Self: Sized, diff --git a/src/ws/handshake.rs b/src/ws/handshake.rs index 7f4df93..26d17bf 100644 --- a/src/ws/handshake.rs +++ b/src/ws/handshake.rs @@ -80,7 +80,8 @@ impl Handshaker { let mut response = Response::new(&mut headers); response.parse(self.inbound_buffer.view()).map_err(io::Error::other)?; if response.code.unwrap() != StatusCode::SWITCHING_PROTOCOLS.as_u16() { - return Err(io::Error::other("unable to switch protocols")); + let reason = response.reason.unwrap_or_default(); + return Err(io::Error::other(format!("unable to switch protocols, reason: {}", reason))); } self.state = Completed; }