Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
rustdoc::broken_intra_doc_links
)]

#[cfg(feature = "stream")]
use std::pin::Pin;
use std::{
io,
marker::PhantomData,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
task::{Context, Poll},
time::Duration,
};
#[cfg(feature = "stream")]
use std::{pin::Pin, task::ready};

#[cfg(feature = "stream")]
use futures_core::Stream;
Expand Down Expand Up @@ -88,6 +88,8 @@ impl DualstackPinger {
DualstackMeasureManyStream {
v4: self.v4.measure_many(addresses_v4),
v6: self.v6.measure_many(addresses_v6),
v4_done: false,
v6_done: false,
}
}
}
Expand All @@ -105,16 +107,30 @@ impl DualstackPinger {
pub struct DualstackMeasureManyStream<'a, I: Iterator<Item = IpAddr>> {
v4: MeasureManyStream<'a, Ipv4Addr, FilterIpAddr<I, Ipv4Addr>>,
v6: MeasureManyStream<'a, Ipv6Addr, FilterIpAddr<I, Ipv6Addr>>,
v4_done: bool,
v6_done: bool,
}

impl<I: Iterator<Item = IpAddr>> DualstackMeasureManyStream<'_, I> {
pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(IpAddr, Duration)> {
if let Poll::Ready((v4, rtt)) = self.v4.poll_next_unpin(cx) {
return Poll::Ready((IpAddr::V4(v4), rtt));
pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(IpAddr, Duration)>> {
if !self.v4_done {
match self.v4.poll_next_unpin(cx) {
Poll::Ready(Some((v4, rtt))) => return Poll::Ready(Some((IpAddr::V4(v4), rtt))),
Poll::Ready(None) => self.v4_done = true,
Poll::Pending => {}
}
}

if !self.v6_done {
match self.v6.poll_next_unpin(cx) {
Poll::Ready(Some((v6, rtt))) => return Poll::Ready(Some((IpAddr::V6(v6), rtt))),
Poll::Ready(None) => self.v6_done = true,
Poll::Pending => {}
}
}

if let Poll::Ready((v6, rtt)) = self.v6.poll_next_unpin(cx) {
return Poll::Ready((IpAddr::V6(v6), rtt));
if self.v4_done && self.v6_done {
return Poll::Ready(None);
}

Poll::Pending
Expand All @@ -126,8 +142,7 @@ impl<I: Iterator<Item = IpAddr> + Unpin> Stream for DualstackMeasureManyStream<'
type Item = (IpAddr, Duration);

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let result = ready!(self.as_mut().poll_next_unpin(cx));
Poll::Ready(Some(result))
self.as_mut().poll_next_unpin(cx)
}
}

Expand Down
20 changes: 12 additions & 8 deletions src/pinger.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "stream")]
use std::pin::Pin;
use std::{
collections::HashMap,
future::poll_fn,
Expand All @@ -11,8 +13,6 @@ use std::{
task::{Context, Poll},
time::Duration,
};
#[cfg(feature = "stream")]
use std::{pin::Pin, task::ready};

use bytes::BytesMut;
#[cfg(feature = "stream")]
Expand Down Expand Up @@ -210,15 +210,20 @@ pub struct MeasureManyStream<'a, V: IpVersion, I: Iterator<Item = V>> {
}

impl<V: IpVersion, I: Iterator<Item = V>> MeasureManyStream<'_, V, I> {
pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(V, Duration)> {
// Try to see if another `MeasureManyStream` got it
if let Poll::Ready(Some((addr, rtt))) = self.poll_next_from_different_round(cx) {
return Poll::Ready((addr, rtt));
pub fn poll_next_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Option<(V, Duration)>> {
// Try to receive a response (may be from a different round)
if let Poll::Ready(maybe_reply) = self.poll_next_from_different_round(cx) {
return Poll::Ready(maybe_reply);
}

// Try to send ICMP echo requests
self.poll_next_icmp_replies(cx);

// Check if we're done: no more addresses to send AND no responses pending
if self.send_queue.peek().is_none() && self.in_flight.is_empty() {
return Poll::Ready(None);
}

Poll::Pending
}

Expand Down Expand Up @@ -269,8 +274,7 @@ impl<V: IpVersion, I: Iterator<Item = V> + Unpin> Stream for MeasureManyStream<'
type Item = (V, Duration);

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let result = ready!(self.as_mut().poll_next_unpin(cx));
Poll::Ready(Some(result))
self.as_mut().poll_next_unpin(cx)
}
}

Expand Down
116 changes: 116 additions & 0 deletions tests/stream_termination.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//! Regression tests for stream termination.
//!
//! The `MeasureManyStream` must properly terminate (return `None`) when all
//! ping requests have been sent and all responses have been received. Without
//! this, `while let Some(...) = stream.next().await` loops hang forever.

use std::{net::IpAddr, time::Duration};

use futures_util::StreamExt;
use massping::DualstackPinger;
use tokio::time;

/// Test that the stream properly terminates after receiving all responses.
///
/// This is a regression test for the bug where `poll_next_unpin` always
/// returned `Poll::Pending` after processing all results, causing
/// `while let` loops to hang indefinitely.
#[tokio::test(flavor = "current_thread")]
async fn stream_terminates_after_single_ping() {
let pinger = DualstackPinger::new().unwrap();
let localhost: IpAddr = "127.0.0.1".parse().unwrap();
let mut stream = pinger.measure_many([localhost].into_iter());

let mut count = 0;

// This should complete - not hang forever
let result = time::timeout(Duration::from_secs(5), async {
while let Some((addr, rtt)) = stream.next().await {
assert_eq!(addr, localhost);
assert!(rtt < Duration::from_secs(1), "RTT too high: {rtt:?}");
count += 1;
}
})
.await;

assert!(
result.is_ok(),
"stream did not terminate - hung in while let loop"
);
assert_eq!(count, 1, "expected exactly 1 ping response");
}

/// Test that the stream properly terminates after receiving multiple responses.
///
/// Note: We use different addresses because `in_flight` is keyed by address,
/// so pinging the same address multiple times in one `measure_many` call
/// would overwrite the previous entry.
#[tokio::test(flavor = "current_thread")]
async fn stream_terminates_after_multiple_pings() {
let pinger = DualstackPinger::new().unwrap();

// Use different loopback addresses (127.0.0.x all route to localhost)
let addresses: Vec<IpAddr> = vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
"127.0.0.3".parse().unwrap(),
];
let mut stream = pinger.measure_many(addresses.iter().copied());

let mut count = 0;

let result = time::timeout(Duration::from_secs(5), async {
while let Some((_addr, rtt)) = stream.next().await {
assert!(rtt < Duration::from_secs(1), "RTT too high: {rtt:?}");
count += 1;
}
})
.await;

assert!(
result.is_ok(),
"stream did not terminate - hung in while let loop"
);
assert_eq!(count, 3, "expected exactly 3 ping responses");
}

/// Test that an empty address list terminates immediately.
#[tokio::test(flavor = "current_thread")]
async fn stream_terminates_with_empty_input() {
let pinger = DualstackPinger::new().unwrap();
let addresses: Vec<IpAddr> = vec![];
let mut stream = pinger.measure_many(addresses.into_iter());

let result = time::timeout(Duration::from_secs(1), async {
let first = stream.next().await;
assert!(first.is_none(), "expected None for empty address list");
})
.await;

assert!(result.is_ok(), "stream did not terminate for empty input");
}

/// Test stream termination with multi_thread runtime as a baseline.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn stream_terminates_multi_thread() {
let pinger = DualstackPinger::new().unwrap();
let localhost: IpAddr = "127.0.0.1".parse().unwrap();
let mut stream = pinger.measure_many([localhost].into_iter());

let mut count = 0;

let result = time::timeout(Duration::from_secs(5), async {
while let Some((addr, rtt)) = stream.next().await {
assert_eq!(addr, localhost);
assert!(rtt < Duration::from_secs(1), "RTT too high: {rtt:?}");
count += 1;
}
})
.await;

assert!(
result.is_ok(),
"stream did not terminate - hung in while let loop"
);
assert_eq!(count, 1, "expected exactly 1 ping response");
}