diff --git a/src/either.rs b/src/either.rs index 5b00ec96..602ad255 100644 --- a/src/either.rs +++ b/src/either.rs @@ -1,3 +1,4 @@ +#[derive(Debug)] pub(crate) enum Either { Left(L), Right(R), diff --git a/src/lib.rs b/src/lib.rs index 1f830b77..0a1cef49 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,9 @@ mod throughput; #[cfg(feature = "tester")] pub use self::throughput::ThroughputMonitoring; +#[cfg(test)] +pub mod test_utils; + pub use self::config::Config; pub use self::error::{ErrorKind, Result}; pub use self::net::{LinkConditioner, Socket, SocketEvent}; diff --git a/src/net.rs b/src/net.rs index ede86d33..99ee0e01 100644 --- a/src/net.rs +++ b/src/net.rs @@ -2,6 +2,8 @@ //! You can think of the socket, connection management, congestion control. mod connection; +mod connection_impl; +mod connection_manager; mod events; mod link_conditioner; mod quality; @@ -10,6 +12,8 @@ mod virtual_connection; pub mod constants; +pub use self::connection::{Connection, ConnectionEventAddress, ConnectionMessenger}; +pub use self::connection_manager::{ConnectionManager, DatagramSocket}; pub use self::events::SocketEvent; pub use self::link_conditioner::LinkConditioner; pub use self::quality::{NetworkQuality, RttMeasurer}; diff --git a/src/net/connection.rs b/src/net/connection.rs index 2ee26360..37f7e5be 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -1,187 +1,73 @@ -pub use crate::net::{NetworkQuality, RttMeasurer, VirtualConnection}; - -use crate::config::Config; -use crate::either::Either::{self, Left, Right}; -use std::{ - collections::HashMap, - net::SocketAddr, - time::{Duration, Instant}, -}; - -/// Maintains a registry of active "connections". Essentially, when we receive a packet on the -/// socket from a particular `SocketAddr`, we will track information about it here. -#[derive(Debug)] -pub struct ActiveConnections { - connections: HashMap, -} - -impl ActiveConnections { - pub fn new() -> Self { - Self { - connections: HashMap::new(), - } - } - - /// Try to get a `VirtualConnection` by address. If the connection does not exist, it will be - /// inserted and returned. - pub fn get_or_insert_connection( - &mut self, - address: SocketAddr, - config: &Config, - time: Instant, - ) -> &mut VirtualConnection { - self.connections - .entry(address) - .or_insert_with(|| VirtualConnection::new(address, config, time)) - } - - /// Try to get or create a [VirtualConnection] by address. If the connection does not exist, it will be - /// created and returned, but not inserted into the table of active connections. - pub(crate) fn get_or_create_connection( - &mut self, - address: SocketAddr, - config: &Config, - time: Instant, - ) -> Either<&mut VirtualConnection, VirtualConnection> { - if let Some(connection) = self.connections.get_mut(&address) { - Left(connection) - } else { - Right(VirtualConnection::new(address, config, time)) - } - } - - /// Removes the connection from `ActiveConnections` by socket address. - pub fn remove_connection( - &mut self, - address: &SocketAddr, - ) -> Option<(SocketAddr, VirtualConnection)> { - self.connections.remove_entry(address) - } - - /// Check for and return `VirtualConnection`s which have been idling longer than `max_idle_time`. - pub fn idle_connections(&mut self, max_idle_time: Duration, time: Instant) -> Vec { - self.connections - .iter() - .filter(|(_, connection)| connection.last_heard(time) >= max_idle_time) - .map(|(address, _)| *address) - .collect() - } - - /// Get a list of addresses of dead connections - pub fn dead_connections(&mut self) -> Vec { - self.connections - .iter() - .filter(|(_, connection)| connection.should_be_dropped()) - .map(|(address, _)| *address) - .collect() - } - - /// Check for and return `VirtualConnection`s which have not sent anything for a duration of at least `heartbeat_interval`. - pub fn heartbeat_required_connections( - &mut self, - heartbeat_interval: Duration, - time: Instant, - ) -> impl Iterator { - self.connections - .iter_mut() - .filter(move |(_, connection)| connection.last_sent(time) >= heartbeat_interval) - .map(|(_, connection)| connection) - } - - /// Returns true if the given connection exists. - pub fn exists(&self, address: &SocketAddr) -> bool { - self.connections.contains_key(&address) - } - - /// Returns the number of connected clients. - #[cfg(test)] - pub(crate) fn count(&self) -> usize { - self.connections.len() - } -} - -#[cfg(test)] -mod tests { - use super::{ActiveConnections, Config}; - use std::{ - sync::Arc, - time::{Duration, Instant}, - }; - - const ADDRESS: &str = "127.0.0.1:12345"; - - #[test] - fn connection_timed_out() { - let mut connections = ActiveConnections::new(); - let config = Config::default(); - - let now = Instant::now(); - - // add 10 clients - for i in 0..10 { - connections.get_or_insert_connection( - format!("127.0.0.1:122{}", i).parse().unwrap(), - &config, - now, - ); - } - - assert_eq!(connections.count(), 10); - - let wait = Duration::from_millis(200); - - #[cfg(not(windows))] - let epsilon = Duration::from_nanos(1); - #[cfg(windows)] - let epsilon = Duration::from_millis(1); - - let timed_out_connections = connections.idle_connections(wait, now + wait - epsilon); - assert_eq!(timed_out_connections.len(), 0); - - let timed_out_connections = connections.idle_connections(wait, now + wait + epsilon); - assert_eq!(timed_out_connections.len(), 10); - } - - #[test] - fn insert_connection() { - let mut connections = ActiveConnections::new(); - let config = Config::default(); - - let address = ADDRESS.parse().unwrap(); - connections.get_or_insert_connection(address, &config, Instant::now()); - assert!(connections.connections.contains_key(&address)); - } - - #[test] - fn insert_existing_connection() { - let mut connections = ActiveConnections::new(); - let config = Config::default(); - - let address = ADDRESS.parse().unwrap(); - connections.get_or_insert_connection(address, &config, Instant::now()); - assert!(connections.connections.contains_key(&address)); - connections.get_or_insert_connection(address, &config, Instant::now()); - assert!(connections.connections.contains_key(&address)); - } - - #[test] - fn remove_connection() { - let mut connections = ActiveConnections::new(); - let config = Arc::new(Config::default()); - - let address = ADDRESS.parse().unwrap(); - connections.get_or_insert_connection(address, &config, Instant::now()); - assert!(connections.connections.contains_key(&address)); - connections.remove_connection(&address); - assert!(!connections.connections.contains_key(&address)); - } - - #[test] - fn remove_non_existent_connection() { - let mut connections = ActiveConnections::new(); - - let address = &ADDRESS.parse().unwrap(); - connections.remove_connection(address); - assert!(!connections.connections.contains_key(address)); - } -} +use crate::config::Config; + +use std::{self, fmt::Debug, net::SocketAddr, time::Instant}; + +/// Allows connection to send packet, send event and get global configuration. +pub trait ConnectionMessenger { + /// Returns global configuration. + fn config(&self) -> &Config; + + /// Sends a connection event. + fn send_event(&mut self, address: &SocketAddr, event: ReceiveEvent); + /// Sends a packet. + fn send_packet(&mut self, address: &SocketAddr, payload: &[u8]); +} + +/// Returns an address of an event. +/// This is used by a `ConnectionManager`, because it doesn't know anything about connection events. +pub trait ConnectionEventAddress { + /// Returns event address + fn address(&self) -> SocketAddr; +} + +/// Allows to implement actual connection. +/// Defines a type of `Send` and `Receive` events, that will be used by a connection. +pub trait Connection: Debug { + /// Defines a user event type. + type SendEvent: Debug + ConnectionEventAddress; + /// Defines a connection event type. + type ReceiveEvent: Debug + ConnectionEventAddress; + + /// Creates new connection and initialize it by sending an connection event to the user. + /// * messenger - allows to send packets and events, also provides a config. + /// * address - defines a address that connection is associated with. + /// * time - creation time, used by connection, so that it doesn't get dropped immediately or send heartbeat packet. + /// * initial_data - if initiated by remote host, this will hold that a packet data. + fn create_connection( + messenger: &mut impl ConnectionMessenger, + address: SocketAddr, + time: Instant, + initial_data: Option<&[u8]>, + ) -> Self; + + /// Determines if the connection should be dropped due to its state. + fn should_drop( + &mut self, + messenger: &mut impl ConnectionMessenger, + time: Instant, + ) -> bool; + + /// Processes a received packet: parse it and emit an event. + fn process_packet( + &mut self, + messenger: &mut impl ConnectionMessenger, + payload: &[u8], + time: Instant, + ); + + /// Processes a received event and send a packet. + fn process_event( + &mut self, + messenger: &mut impl ConnectionMessenger, + event: Self::SendEvent, + time: Instant, + ); + + /// Processes various connection-related tasks: resend dropped packets, send heartbeat packet, etc... + /// This function gets called frequently. + fn update( + &mut self, + messenger: &mut impl ConnectionMessenger, + time: Instant, + ); +} diff --git a/src/net/connection_impl.rs b/src/net/connection_impl.rs new file mode 100644 index 00000000..d1a4ac66 --- /dev/null +++ b/src/net/connection_impl.rs @@ -0,0 +1,174 @@ +use super::{ + events::SocketEvent, Connection, ConnectionEventAddress, ConnectionMessenger, VirtualConnection, +}; +use crate::error::{ErrorKind, Result}; +use crate::packet::{DeliveryGuarantee, OutgoingPackets, Packet, PacketInfo}; + +use std::net::SocketAddr; +use std::time::Instant; + +use log::error; + +/// Required by `ConnectionManager` to properly handle connection event. +impl ConnectionEventAddress for SocketEvent { + /// Returns event address + fn address(&self) -> SocketAddr { + match self { + SocketEvent::Packet(packet) => packet.addr(), + SocketEvent::Connect(addr) => *addr, + SocketEvent::Timeout(addr) => *addr, + } + } +} + +/// Required by `ConnectionManager` to properly handle user event. +impl ConnectionEventAddress for Packet { + /// Returns event address + fn address(&self) -> SocketAddr { + self.addr() + } +} + +impl Connection for VirtualConnection { + /// Defines a user event type. + type SendEvent = Packet; + /// Defines a connection event type. + type ReceiveEvent = SocketEvent; + + /// Creates new connection and initialize it by sending an connection event to the user. + /// * address - defines a address that connection is associated with. + /// * time - creation time, used by connection, so that it doesn't get dropped immediately or send heartbeat packet. + /// * initial_data - if initiated by remote host, this will hold that a packet data. + fn create_connection( + messenger: &mut impl ConnectionMessenger, + address: SocketAddr, + time: Instant, + initial_data: Option<&[u8]>, + ) -> VirtualConnection { + // Emit connect event if this is initiated by the remote host. + if initial_data.is_some() { + messenger.send_event(&address, SocketEvent::Connect(address)); + } + VirtualConnection::new(address, messenger.config(), time) + } + + /// Determines if the given `Connection` should be dropped due to its state. + fn should_drop( + &mut self, + messenger: &mut impl ConnectionMessenger, + time: Instant, + ) -> bool { + let should_drop = self.packets_in_flight() > messenger.config().max_packets_in_flight + || self.last_heard(time) >= messenger.config().idle_connection_timeout; + if should_drop { + messenger.send_event( + &self.remote_address, + SocketEvent::Timeout(self.remote_address), + ); + } + should_drop + } + + /// Processes a received packet: parse it and emit an event. + fn process_packet( + &mut self, + messenger: &mut impl ConnectionMessenger, + payload: &[u8], + time: Instant, + ) { + if !payload.is_empty() { + match self.process_incoming(payload, time) { + Ok(packets) => { + for incoming in packets { + messenger.send_event(&self.remote_address, SocketEvent::Packet(incoming.0)); + } + } + Err(err) => error!("Error occured processing incomming packet: {:?}", err), + } + } else { + error!( + "Error processing packet: {}", + ErrorKind::ReceivedDataToShort + ); + } + } + + /// Processes a received event and send a packet. + fn process_event( + &mut self, + messenger: &mut impl ConnectionMessenger, + event: Self::SendEvent, + time: Instant, + ) { + let addr = self.remote_address; + send_packets( + messenger, + &addr, + self.process_outgoing( + PacketInfo::user_packet( + event.payload(), + event.delivery_guarantee(), + event.order_guarantee(), + ), + None, + time, + ), + "user packet", + ); + } + + /// Processes various connection-related tasks: resend dropped packets, send heartbeat packet, etc... + /// This function gets called very frequently. + fn update( + &mut self, + messenger: &mut impl ConnectionMessenger, + time: Instant, + ) { + // Resend dropped packets + for dropped in self.gather_dropped_packets() { + let packets = self.process_outgoing( + PacketInfo { + packet_type: dropped.packet_type, + payload: &dropped.payload, + // Because a delivery guarantee is only sent with reliable packets + delivery: DeliveryGuarantee::Reliable, + // This is stored with the dropped packet because they could be mixed + ordering: dropped.ordering_guarantee, + }, + dropped.item_identifier, + time, + ); + send_packets(messenger, &self.remote_address, packets, "dropped packets"); + } + + // Send heartbeat packets if required + if let Some(heartbeat_interval) = messenger.config().heartbeat_interval { + let addr = self.remote_address; + if self.last_sent(time) >= heartbeat_interval { + send_packets( + messenger, + &addr, + self.process_outgoing(PacketInfo::heartbeat_packet(&[]), None, time), + "heatbeat packet", + ); + } + } + } +} + +// Sends multiple outgoing packets. +fn send_packets( + ctx: &mut impl ConnectionMessenger, + address: &SocketAddr, + packets: Result, + err_context: &str, +) { + match packets { + Ok(packets) => { + for outgoing in packets { + ctx.send_packet(address, &outgoing.contents()); + } + } + Err(error) => error!("Error occured processing {}: {:?}", err_context, error), + } +} diff --git a/src/net/connection_manager.rs b/src/net/connection_manager.rs new file mode 100644 index 00000000..ed14b6a6 --- /dev/null +++ b/src/net/connection_manager.rs @@ -0,0 +1,913 @@ +use crate::{ + config::Config, net::Connection, net::ConnectionEventAddress, net::ConnectionMessenger, +}; +use crossbeam_channel::{self, unbounded, Receiver, Sender}; +use log::error; +use std::{self, collections::HashMap, fmt::Debug, io::Result, net::SocketAddr, time::Instant}; + +// TODO: maybe we can make a breaking change and use this instead of `ConnectionEventAddress` trait? +// #[derive(Debug)] +// pub struct ConnectionEvent(pub SocketAddr, pub Event); + +/// A datagram socket is a type of network socket which provides a connectionless point for sending or receiving data packets. +pub trait DatagramSocket: Debug { + /// Sends a single packet to the socket. + fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> Result; + + /// Receives a single packet from the socket. + fn receive_packet<'a>(&mut self, buffer: &'a mut [u8]) -> Result<(&'a [u8], SocketAddr)>; + + /// Returns the socket address that this socket was created from. + fn local_addr(&self) -> Result; + + /// Returns whether socket operates in blocking or nonblocking mode. + fn is_blocking_mode(&self) -> bool; +} + +// This will be used by a `Connection`. +#[derive(Debug)] +struct SocketEventSenderAndConfig { + config: Config, + socket: TSocket, + event_sender: Sender, +} + +impl + SocketEventSenderAndConfig +{ + fn new(config: Config, socket: TSocket, event_sender: Sender) -> Self { + Self { + config, + socket, + event_sender, + } + } +} + +impl ConnectionMessenger + for SocketEventSenderAndConfig +{ + fn config(&self) -> &Config { + &self.config + } + + fn send_event(&mut self, _address: &SocketAddr, event: ReceiveEvent) { + self.event_sender.send(event).expect("Receiver must exists"); + } + + fn send_packet(&mut self, address: &SocketAddr, payload: &[u8]) { + if let Err(err) = self.socket.send_packet(address, payload) { + error!("Error occured sending a packet (to {}): {}", address, err) + } + } +} + +/// Implements a concept of connections on top of datagram socket. +/// Connection capabilities depends on what is an actual `Connection` type. +/// Connection type also defines a type of sending and receiving events. +#[derive(Debug)] +pub struct ConnectionManager { + connections: HashMap, + receive_buffer: Vec, + user_event_receiver: Receiver, + messenger: SocketEventSenderAndConfig, + // Stores event receiver, so that user can clone it. + event_receiver: Receiver, + // Stores event sender, so that user can clone it. + user_event_sender: Sender, +} + +impl ConnectionManager { + /// Creates an instance of `ConnectionManager` by passing a socket and config. + pub fn new(socket: TSocket, config: Config) -> Self { + let (event_sender, event_receiver) = unbounded(); + let (user_event_sender, user_event_receiver) = unbounded(); + ConnectionManager { + receive_buffer: vec![0; config.receive_buffer_max_size], + connections: Default::default(), + user_event_receiver, + messenger: SocketEventSenderAndConfig::new(config, socket, event_sender), + user_event_sender, + event_receiver, + } + } + + /// Process any inbound/outbound packets and events. + /// Process connection specific logic for active connections. + /// Remove dropped connections from active connections list. + pub fn manual_poll(&mut self, time: Instant) { + let messenger = &mut self.messenger; + // First we pull all newly arrived packets and handle them + loop { + match messenger + .socket + .receive_packet(self.receive_buffer.as_mut()) + { + Ok((payload, address)) => { + if let Some(conn) = self.connections.get_mut(&address) { + conn.process_packet(messenger, payload, time); + } else { + // Create connection, but do not add to active connections list + let mut conn = + TConnection::create_connection(messenger, address, time, Some(payload)); + conn.process_packet(messenger, payload, time); + } + } + Err(e) => { + if e.kind() != std::io::ErrorKind::WouldBlock { + error!("Encountered an error receiving data: {:?}", e); + } + break; + } + } + // To prevent from blocking, break after receiving first packet + if messenger.socket.is_blocking_mode() { + break; + } + } + + // Now grab all the waiting packets and send them + while let Ok(event) = self.user_event_receiver.try_recv() { + // get or create connection + let conn = self.connections.entry(event.address()).or_insert_with(|| { + TConnection::create_connection(messenger, event.address(), time, None) + }); + conn.process_event(messenger, event, time); + } + + // Update all connections + for conn in self.connections.values_mut() { + conn.update(messenger, time); + } + + // Iterate through all connections and remove those that should be dropped + self.connections + .retain(|_, conn| !conn.should_drop(messenger, time)); + } + + /// Returns a handle to the event sender which provides a thread-safe way to enqueue user events + /// to be processed. This should be used when the socket is busy running its polling loop in a + /// separate thread. + pub fn event_sender(&self) -> &Sender { + &self.user_event_sender + } + + /// Returns a handle to the event receiver which provides a thread-safe way to retrieve events + /// from the connections. This should be used when the socket is busy running its polling loop in + /// a separate thread. + pub fn event_receiver(&self) -> &Receiver { + &self.event_receiver + } + + /// Returns socket reference. + pub fn socket(&self) -> &TSocket { + &self.messenger.socket + } + + /// Returns socket mutable reference. + #[allow(dead_code)] + pub fn socket_mut(&mut self) -> &mut TSocket { + &mut self.messenger.socket + } + + /// Returns a number of active connections. + #[cfg(test)] + pub fn connections_count(&self) -> usize { + self.connections.len() + } +} + +#[cfg(test)] +mod tests { + use crate::net::LinkConditioner; + use crate::test_utils::*; + use crate::{Config, Packet, SocketEvent}; + + use std::{ + collections::HashSet, + net::SocketAddr, + time::{Duration, Instant}, + }; + + /// The socket address of where the server is located. + const SERVER_ADDR: &str = "127.0.0.1:10001"; + // The client address from where the data is sent. + const CLIENT_ADDR: &str = "127.0.0.1:10002"; + + fn client_address() -> SocketAddr { + CLIENT_ADDR.parse().unwrap() + } + + fn server_address() -> SocketAddr { + SERVER_ADDR.parse().unwrap() + } + + fn create_server_client_network() -> (FakeSocket, FakeSocket, NetworkEmulator) { + let network = NetworkEmulator::default(); + let server = FakeSocket::bind(&network, server_address(), Config::default()).unwrap(); + let client = FakeSocket::bind(&network, client_address(), Config::default()).unwrap(); + (server, client, network) + } + + fn create_server_client(config: Config) -> (FakeSocket, FakeSocket) { + let network = NetworkEmulator::default(); + let server = FakeSocket::bind(&network, server_address(), config.clone()).unwrap(); + let client = FakeSocket::bind(&network, client_address(), config).unwrap(); + (server, client) + } + + #[test] + fn using_sender_and_receiver() { + let (mut server, mut client, _) = create_server_client_network(); + + let sender = client.get_packet_sender(); + let receiver = server.get_event_receiver(); + + sender + .send(Packet::reliable_unordered( + server_address(), + b"Hello world!".to_vec(), + )) + .unwrap(); + + let time = Instant::now(); + client.manual_poll(time); + server.manual_poll(time); + + assert_eq![Ok(SocketEvent::Connect(client_address())), receiver.recv()]; + if let SocketEvent::Packet(packet) = receiver.recv().unwrap() { + assert_eq![b"Hello world!", packet.payload()]; + } else { + panic!["Did not receive a packet when it should"]; + } + } + + #[test] + fn initial_packet_is_resent() { + let (mut server, mut client, network) = create_server_client_network(); + let time = Instant::now(); + + // Send a packet that the server ignores/drops + client + .send(Packet::reliable_unordered( + server_address(), + b"Do not arrive".to_vec(), + )) + .unwrap(); + client.manual_poll(time); + + // Drop the inbound packet, this simulates a network error + network.clear_packets(server_address()); + + // Send a packet that the server receives + for id in 0..u8::max_value() { + client + .send(Packet::reliable_unordered(server_address(), vec![id])) + .unwrap(); + + server + .send(Packet::reliable_unordered(client_address(), vec![id])) + .unwrap(); + + client.manual_poll(time); + server.manual_poll(time); + + while let Some(SocketEvent::Packet(pkt)) = server.recv() { + if pkt.payload() == b"Do not arrive" { + return; + } + } + while let Some(_) = client.recv() {} + } + + panic!["Did not receive the ignored packet"]; + } + + #[test] + fn receiving_does_not_allow_denial_of_service() { + let (mut server, mut client, _) = create_server_client_network(); + // Send a bunch of packets to a server + for _ in 0..3 { + client + .send(Packet::unreliable( + server_address(), + vec![1, 2, 3, 4, 5, 6, 7, 8, 9], + )) + .unwrap(); + } + + let time = Instant::now(); + + client.manual_poll(time); + server.manual_poll(time); + + for _ in 0..6 { + assert![server.recv().is_some()]; + } + assert![server.recv().is_none()]; + + // The server shall not have any connection in its connection table even though it received + // packets + assert_eq![0, server.connection_count()]; + + server + .send(Packet::unreliable(client_address(), vec![1])) + .unwrap(); + + server.manual_poll(time); + + // The server only adds to its table after having sent explicitly + assert_eq![1, server.connection_count()]; + } + + #[test] + fn initial_sequenced_is_resent() { + let (mut server, mut client, network) = create_server_client_network(); + let time = Instant::now(); + + // Send a packet that the server ignores/drops + client + .send(Packet::reliable_sequenced( + server_address(), + b"Do not arrive".to_vec(), + None, + )) + .unwrap(); + client.manual_poll(time); + + // Drop the inbound packet, this simulates a network error + network.clear_packets(server_address()); + + // Send a packet that the server receives + for id in 0..36 { + client + .send(Packet::reliable_sequenced(server_address(), vec![id], None)) + .unwrap(); + + server + .send(Packet::reliable_sequenced(client_address(), vec![id], None)) + .unwrap(); + + client.manual_poll(time); + server.manual_poll(time); + + while let Some(SocketEvent::Packet(pkt)) = server.recv() { + if pkt.payload() == b"Do not arrive" { + panic!["Sequenced packet arrived while it should not"]; + } + } + while let Some(_) = client.recv() {} + } + } + + #[test] + fn initial_ordered_is_resent() { + let (mut server, mut client, network) = create_server_client_network(); + let time = Instant::now(); + + // Send a packet that the server ignores/drops + client + .send(Packet::reliable_ordered( + server_address(), + b"Do not arrive".to_vec(), + None, + )) + .unwrap(); + client.manual_poll(time); + + // Drop the inbound packet, this simulates a network error + network.clear_packets(server_address()); + + // Send a packet that the server receives + for id in 0..35 { + client + .send(Packet::reliable_ordered(server_address(), vec![id], None)) + .unwrap(); + + server + .send(Packet::reliable_ordered(client_address(), vec![id], None)) + .unwrap(); + + client.manual_poll(time); + server.manual_poll(time); + + while let Some(SocketEvent::Packet(pkt)) = server.recv() { + if pkt.payload() == b"Do not arrive" { + return; + } + } + while let Some(_) = client.recv() {} + } + + panic!["Did not receive the ignored packet"]; + } + + #[test] + fn do_not_duplicate_sequenced_packets_when_received() { + let (mut server, mut client, _) = create_server_client_network(); + let time = Instant::now(); + + for id in 0..100 { + client + .send(Packet::reliable_sequenced(server_address(), vec![id], None)) + .unwrap(); + client.manual_poll(time); + server.manual_poll(time); + } + + let mut seen = HashSet::new(); + + while let Some(message) = server.recv() { + match message { + SocketEvent::Connect(_) => {} + SocketEvent::Packet(packet) => { + let byte = packet.payload()[0]; + assert![!seen.contains(&byte)]; + seen.insert(byte); + } + SocketEvent::Timeout(_) => { + panic!["This should not happen, as we've not advanced time"]; + } + } + } + + assert_eq![100, seen.len()]; + } + + #[test] + fn more_than_65536_sequenced_packets() { + let (mut server, mut client, _) = create_server_client_network(); + // Acknowledge the client + server + .send(Packet::unreliable(client_address(), vec![0])) + .unwrap(); + + let time = Instant::now(); + + for id in 0..65536 + 100 { + client + .send(Packet::unreliable_sequenced( + server_address(), + id.to_string().as_bytes().to_vec(), + None, + )) + .unwrap(); + client.manual_poll(time); + server.manual_poll(time); + } + + let mut cnt = 0; + while let Some(message) = server.recv() { + match message { + SocketEvent::Connect(_) => {} + SocketEvent::Packet(_) => { + cnt += 1; + } + SocketEvent::Timeout(_) => { + panic!["This should not happen, as we've not advanced time"]; + } + } + } + assert_eq![65536 + 100, cnt]; + } + + #[test] + fn sequenced_packets_pathological_case() { + let mut config = Config::default(); + config.max_packets_in_flight = 100; + let (_, mut client) = create_server_client(config.clone()); + + let time = Instant::now(); + + for id in 0..101 { + client + .send(Packet::reliable_sequenced( + server_address(), + id.to_string().as_bytes().to_vec(), + None, + )) + .unwrap(); + client.manual_poll(time); + + while let Some(event) = client.recv() { + match event { + SocketEvent::Timeout(remote_addr) => { + assert_eq![100, id]; + assert_eq![remote_addr, server_address()]; + return; + } + _ => { + panic!["No other event possible"]; + } + } + } + } + + panic!["Should have received a timeout event"]; + } + + #[test] + fn manual_polling_socket() { + let (mut server, mut client, _) = create_server_client_network(); + for _ in 0..3 { + client + .send(Packet::unreliable( + server_address(), + vec![1, 2, 3, 4, 5, 6, 7, 8, 9], + )) + .unwrap(); + } + + let time = Instant::now(); + + client.manual_poll(time); + server.manual_poll(time); + + assert!(server.recv().is_some()); + assert!(server.recv().is_some()); + assert!(server.recv().is_some()); + } + + #[test] + fn can_send_and_receive() { + let (mut server, mut client, _) = create_server_client_network(); + for _ in 0..3 { + client + .send(Packet::unreliable( + server_address(), + vec![1, 2, 3, 4, 5, 6, 7, 8, 9], + )) + .unwrap(); + } + + let now = Instant::now(); + client.manual_poll(now); + server.manual_poll(now); + + assert!(server.recv().is_some()); + assert!(server.recv().is_some()); + assert!(server.recv().is_some()); + } + + #[test] + fn connect_event_occurs() { + let (mut server, mut client, _) = create_server_client_network(); + + client + .send(Packet::unreliable(server_address(), vec![0, 1, 2])) + .unwrap(); + + let now = Instant::now(); + client.manual_poll(now); + server.manual_poll(now); + + assert_eq!( + server.recv().unwrap(), + SocketEvent::Connect(client_address()) + ); + } + + #[test] + fn disconnect_event_occurs() { + let mut config = Config::default(); + config.idle_connection_timeout = Duration::from_millis(1); + let (mut server, mut client) = create_server_client(config.clone()); + + client + .send(Packet::unreliable(server_address(), vec![0, 1, 2])) + .unwrap(); + + let now = Instant::now(); + client.manual_poll(now); + server.manual_poll(now); + + assert_eq!( + server.recv().unwrap(), + SocketEvent::Connect(client_address()) + ); + assert_eq!( + server.recv().unwrap(), + SocketEvent::Packet(Packet::unreliable(client_address(), vec![0, 1, 2])) + ); + + // Acknowledge the client + server + .send(Packet::unreliable(client_address(), vec![])) + .unwrap(); + + server.manual_poll(now); + client.manual_poll(now); + + // Make sure the connection was successful on the client side + assert_eq!( + client.recv().unwrap(), + SocketEvent::Packet(Packet::unreliable(server_address(), vec![])) + ); + + // Give just enough time for no timeout events to occur (yet) + server.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1)); + client.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1)); + + assert_eq!(server.recv(), None); + assert_eq!(client.recv(), None); + + // Give enough time for timeouts to be detected + server.manual_poll(now + config.idle_connection_timeout); + client.manual_poll(now + config.idle_connection_timeout); + + assert_eq!( + server.recv().unwrap(), + SocketEvent::Timeout(client_address()) + ); + assert_eq!( + client.recv().unwrap(), + SocketEvent::Timeout(server_address()) + ); + } + + #[test] + fn heartbeats_work() { + let mut config = Config::default(); + config.idle_connection_timeout = Duration::from_millis(10); + config.heartbeat_interval = Some(Duration::from_millis(4)); + let (mut server, mut client) = create_server_client(config.clone()); + // Initiate a connection + client + .send(Packet::unreliable(server_address(), vec![0, 1, 2])) + .unwrap(); + + let now = Instant::now(); + client.manual_poll(now); + server.manual_poll(now); + + // Make sure the connection was successful on the server side + assert_eq!( + server.recv().unwrap(), + SocketEvent::Connect(client_address()) + ); + assert_eq!( + server.recv().unwrap(), + SocketEvent::Packet(Packet::unreliable(client_address(), vec![0, 1, 2])) + ); + + // Acknowledge the client + // This way, the server also knows about the connection and sends heartbeats + server + .send(Packet::unreliable(client_address(), vec![])) + .unwrap(); + + server.manual_poll(now); + client.manual_poll(now); + + // Make sure the connection was successful on the client side + assert_eq!( + client.recv().unwrap(), + SocketEvent::Packet(Packet::unreliable(server_address(), vec![])) + ); + + // Give time to send heartbeats + client.manual_poll(now + config.heartbeat_interval.unwrap()); + server.manual_poll(now + config.heartbeat_interval.unwrap()); + + // Give time for timeouts to occur if no heartbeats were sent + client.manual_poll(now + config.idle_connection_timeout); + server.manual_poll(now + config.idle_connection_timeout); + + // Assert that no disconnection events occurred + assert_eq!(client.recv(), None); + assert_eq!(server.recv(), None); + } + + #[test] + fn multiple_sends_should_start_sending_dropped() { + let (mut server, mut client, _) = create_server_client_network(); + + let now = Instant::now(); + + // Send enough packets to ensure that we must have dropped packets. + for i in 0..35 { + client + .send(Packet::unreliable(server_address(), vec![i])) + .unwrap(); + client.manual_poll(now); + } + + let mut events = Vec::new(); + + loop { + server.manual_poll(now); + if let Some(event) = server.recv() { + events.push(event); + } else { + break; + } + } + + // Ensure that we get the correct number of events to the server. + // 35 connect events plus the 35 messages + assert_eq!(events.len(), 70); + + // Finally the server decides to send us a message back. This necessarily will include + // the ack information for 33 of the sent 35 packets. + server + .send(Packet::unreliable(client_address(), vec![0])) + .unwrap(); + server.manual_poll(now); + + // Loop to ensure that the client gets the server message before moving on. + loop { + client.manual_poll(now); + if client.recv().is_some() { + break; + } + } + + // This next sent message should end up sending the 2 unacked messages plus the new messages + // with payload 35 + events.clear(); + client + .send(Packet::unreliable(server_address(), vec![35])) + .unwrap(); + client.manual_poll(now); + + loop { + server.manual_poll(now); + if let Some(event) = server.recv() { + events.push(event); + break; + } + } + + let sent_events: Vec = events + .iter() + .flat_map(|e| match e { + SocketEvent::Packet(p) => Some(p.payload()[0]), + _ => None, + }) + .collect(); + assert_eq!(sent_events, vec![35]); + } + + #[test] + fn really_bad_network_keeps_chugging_along() { + let (mut server, mut client, _) = create_server_client_network(); + + let time = Instant::now(); + + // We give both the server and the client a really bad bidirectional link + let link_conditioner = { + let mut lc = LinkConditioner::new(); + lc.set_packet_loss(0.9); + Some(lc) + }; + + client.set_link_conditioner(link_conditioner.clone()); + server.set_link_conditioner(link_conditioner); + + let mut set = HashSet::new(); + + // We chat 100 packets between the client and server, which will re-send any non-acked + // packets + let mut send_many_packets = |dummy: Option| { + for id in 0..100 { + client + .send(Packet::reliable_unordered( + server_address(), + vec![dummy.unwrap_or(id)], + )) + .unwrap(); + + server + .send(Packet::reliable_unordered(client_address(), vec![255])) + .unwrap(); + + client.manual_poll(time); + server.manual_poll(time); + + while let Some(_) = client.recv() {} + while let Some(event) = server.recv() { + match event { + SocketEvent::Packet(pkt) => { + set.insert(pkt.payload()[0]); + } + SocketEvent::Timeout(_) => { + panic!["Unable to time out, time has not advanced"] + } + SocketEvent::Connect(_) => {} + } + } + } + + set.len() + }; + + // The first chatting sequence sends packets 0..100 from the client to the server. After + // this we just chat with a value of 255 so we don't accidentally overlap those chatting + // packets with the packets we want to ack. + send_many_packets(None); + send_many_packets(Some(255)); + send_many_packets(Some(255)); + send_many_packets(Some(255)); + + // 101 because we have 0..100 and 255 from the dummies + assert_eq![101, send_many_packets(Some(255))]; + } + + #[test] + fn fragmented_ordered_gets_acked() { + let mut config = Config::default(); + config.fragment_size = 10; + let (mut server, mut client) = create_server_client(config.clone()); + + let time = Instant::now(); + let dummy = vec![0]; + + // --- + + client + .send(Packet::unreliable(server_address(), dummy.clone())) + .unwrap(); + client.manual_poll(time); + server + .send(Packet::unreliable(client_address(), dummy.clone())) + .unwrap(); + server.manual_poll(time); + + // --- + + let exceeds = b"Fragmented string".to_vec(); + client + .send(Packet::reliable_ordered(server_address(), exceeds, None)) + .unwrap(); + client.manual_poll(time); + + server.manual_poll(time); + server.manual_poll(time); + server + .send(Packet::reliable_ordered( + client_address(), + dummy.clone(), + None, + )) + .unwrap(); + + client + .send(Packet::unreliable(server_address(), dummy.clone())) + .unwrap(); + client.manual_poll(time); + server.manual_poll(time); + + for _ in 0..4 { + assert![server.recv().is_some()]; + } + assert![server.recv().is_none()]; + + for _ in 0..34 { + client + .send(Packet::reliable_ordered( + server_address(), + dummy.clone(), + None, + )) + .unwrap(); + client.manual_poll(time); + server + .send(Packet::reliable_ordered( + client_address(), + dummy.clone(), + None, + )) + .unwrap(); + server.manual_poll(time); + assert![client.recv().is_some()]; + // If the last iteration returns None here, it indicates we just received a re-sent + // fragment, because `manual_poll` only processes a single incoming UDP packet per + // `manual_poll` if and only if the socket is in blocking mode. + // + // If that functionality is changed, we will receive something unexpected here + match server.recv() { + Some(SocketEvent::Packet(pkt)) => { + assert_eq![dummy, pkt.payload()]; + } + _ => { + panic!["Did not receive expected dummy packet"]; + } + } + } + } + + #[quickcheck_macros::quickcheck] + fn do_not_panic_on_arbitrary_packets(bytes: Vec) { + use crate::net::DatagramSocket; + let network = NetworkEmulator::default(); + let mut server = FakeSocket::bind(&network, server_address(), Config::default()).unwrap(); + let mut client_socket = network.new_socket(client_address()).unwrap(); + + client_socket + .send_packet(&server_address(), &bytes) + .unwrap(); + + let time = Instant::now(); + server.manual_poll(time); + } +} diff --git a/src/net/link_conditioner.rs b/src/net/link_conditioner.rs index a7913ecd..55b97baf 100644 --- a/src/net/link_conditioner.rs +++ b/src/net/link_conditioner.rs @@ -8,7 +8,7 @@ use rand_pcg::Pcg64Mcg as Random; use std::time::Duration; /// Network simulator. Used to simulate network conditions as dropped packets and packet delays. -/// For use in [Socket::set_link_conditioner](crate::net::Socket::set_link_conditioner). +/// For use in [FakeSocket::set_link_conditioner](crate::test_utils::FakeSocket::set_link_conditioner). #[derive(Clone, Debug)] pub struct LinkConditioner { // Value between 0 and 1, representing the % change a packet will be dropped on sending diff --git a/src/net/quality.rs b/src/net/quality.rs index 2968c6b3..39547a50 100644 --- a/src/net/quality.rs +++ b/src/net/quality.rs @@ -80,7 +80,7 @@ impl RttMeasurer { mod test { use super::RttMeasurer; use crate::config::Config; - use crate::net::connection::VirtualConnection; + use crate::net::VirtualConnection; use std::net::ToSocketAddrs; use std::time::{Duration, Instant}; diff --git a/src/net/socket.rs b/src/net/socket.rs index dd425fe4..f7b84636 100644 --- a/src/net/socket.rs +++ b/src/net/socket.rs @@ -1,79 +1,83 @@ -use crate::either::Either::{Left, Right}; use crate::{ config::Config, - error::{ErrorKind, Result}, - net::{connection::ActiveConnections, events::SocketEvent, link_conditioner::LinkConditioner}, - packet::{DeliveryGuarantee, Packet, PacketInfo}, + error::Result, + net::{ + events::SocketEvent, ConnectionManager, DatagramSocket, LinkConditioner, VirtualConnection, + }, + packet::Packet, }; -use crossbeam_channel::{self, unbounded, Receiver, SendError, Sender, TryRecvError}; -use log::error; + +use crossbeam_channel::{self, Receiver, Sender, TryRecvError}; use std::{ - self, io, + self, net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket}, thread::{sleep, yield_now}, time::{Duration, Instant}, }; -// Wrap `LinkConditioner` and `UdpSocket` together +// Wrap `LinkConditioner` and `UdpSocket` together. LinkConditioner is enabled when building with a "tester" feature. #[derive(Debug)] struct SocketWithConditioner { + is_blocking_mode: bool, socket: UdpSocket, link_conditioner: Option, } impl SocketWithConditioner { - /// Creates an instance of `SocketWithConditioner` - pub fn new(socket: UdpSocket, link_conditioner: Option) -> Self { - Self { + pub fn new(socket: UdpSocket, is_blocking_mode: bool) -> Result { + socket.set_nonblocking(!is_blocking_mode)?; + Ok(SocketWithConditioner { + is_blocking_mode, socket, - link_conditioner, - } + link_conditioner: None, + }) } - // In the presence of a link conditioner, we would like it to determine whether or not we should - // send a single packet over the UDP socket. - pub fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> Result { - if let Some(ref mut link) = self.link_conditioner { - if !link.should_send() { - return Ok(0); + #[cfg(feature = "tester")] + pub fn set_link_conditioner(&mut self, link_conditioner: Option) { + self.link_conditioner = link_conditioner; + } +} + +/// Provides a `DatagramSocket` implementation for `SocketWithConditioner` +impl DatagramSocket for SocketWithConditioner { + // When `LinkConditioner` is enabled, it will determine whether packet will be sent or not. + fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> std::io::Result { + if cfg!(feature = "tester") { + if let Some(ref mut link) = &mut self.link_conditioner { + if !link.should_send() { + return Ok(0); + } } } - Ok(self.socket.send_to(payload, addr)?) + self.socket.send_to(payload, addr) } - /// Returns mutable reference of `UdpSocket` - pub fn socket(&mut self) -> &mut UdpSocket { - &mut self.socket + /// Receive a single packet from UDP socket. + fn receive_packet<'a>( + &mut self, + buffer: &'a mut [u8], + ) -> std::io::Result<(&'a [u8], SocketAddr)> { + self.socket + .recv_from(buffer) + .map(move |(recv_len, address)| (&buffer[..recv_len], address)) } - /// Returns the local socket address - pub fn local_addr(&self) -> Result { - Ok(self.socket.local_addr()?) + /// Returns the socket address that this socket was created from. + fn local_addr(&self) -> std::io::Result { + self.socket.local_addr() } - /// Set the link conditioner for this socket. See [LinkConditioner] for further details. - pub fn set_link_conditioner(&mut self, conditioner: Option) { - self.link_conditioner = conditioner; + /// Returns whether socket operates in blocking or nonblocking mode. + fn is_blocking_mode(&self) -> bool { + self.is_blocking_mode } } /// A reliable UDP socket implementation with configurable reliability and ordering guarantees. #[derive(Debug)] pub struct Socket { - socket_wrapper: SocketWithConditioner, - config: Config, - connections: ActiveConnections, - recv_buffer: Vec, - event_sender: Sender, - packet_receiver: Receiver, - - receiver: Receiver, - sender: Sender, -} - -enum UdpSocketState { - MaybeEmpty, - MaybeMore, + handler: ConnectionManager, } impl Socket { @@ -81,7 +85,7 @@ impl Socket { /// Because UDP connections are not persistent, we can only infer the status of the remote /// endpoint by looking to see if they are still sending packets or not pub fn bind(addresses: A) -> Result { - Socket::bind_with_config(addresses, Config::default()) + Self::bind_with_config(addresses, Config::default()) } /// Bind to any local port on the system, if available @@ -108,49 +112,40 @@ impl Socket { } fn bind_internal(socket: UdpSocket, config: Config) -> Result { - socket.set_nonblocking(!config.blocking_mode)?; - let (event_sender, event_receiver) = unbounded(); - let (packet_sender, packet_receiver) = unbounded(); Ok(Socket { - recv_buffer: vec![0; config.receive_buffer_max_size], - socket_wrapper: SocketWithConditioner::new(socket, None), - config, - connections: ActiveConnections::new(), - event_sender, - packet_receiver, - - sender: packet_sender, - receiver: event_receiver, + handler: ConnectionManager::new( + SocketWithConditioner::new(socket, config.blocking_mode)?, + config, + ), }) } /// Returns a handle to the packet sender which provides a thread-safe way to enqueue packets /// to be processed. This should be used when the socket is busy running its polling loop in a /// separate thread. - pub fn get_packet_sender(&mut self) -> Sender { - self.sender.clone() + pub fn get_packet_sender(&self) -> Sender { + self.handler.event_sender().clone() } /// Returns a handle to the event receiver which provides a thread-safe way to retrieve events /// from the socket. This should be used when the socket is busy running its polling loop in /// a separate thread. - pub fn get_event_receiver(&mut self) -> Receiver { - self.receiver.clone() + pub fn get_event_receiver(&self) -> Receiver { + self.handler.event_receiver().clone() } /// Send a packet pub fn send(&mut self, packet: Packet) -> Result<()> { - match self.sender.send(packet) { - Ok(_) => Ok(()), - Err(error) => Err(ErrorKind::SendError(SendError(SocketEvent::Packet( - error.0, - )))), - } + self.handler + .event_sender() + .send(packet) + .expect("Receiver must exists."); + Ok(()) } /// Receive a packet pub fn recv(&mut self) -> Option { - match self.receiver.try_recv() { + match self.handler.event_receiver().try_recv() { Ok(pkt) => Some(pkt), Err(TryRecvError::Empty) => None, Err(TryRecvError::Disconnected) => panic!["This can never happen"], @@ -168,7 +163,7 @@ impl Socket { pub fn start_polling_with_duration(&mut self, sleep_duration: Option) { // Nothing should break out of this loop! loop { - self.manual_poll(Instant::now()); + self.handler.manual_poll(Instant::now()); match sleep_duration { None => yield_now(), Some(duration) => sleep(duration), @@ -178,1140 +173,19 @@ impl Socket { /// Process any inbound/outbound packets and handle idle clients pub fn manual_poll(&mut self, time: Instant) { - // First we pull all newly arrived packets and handle them - loop { - match self.recv_from(time) { - Ok(UdpSocketState::MaybeMore) => continue, - Ok(UdpSocketState::MaybeEmpty) => break, - Err(e) => error!("Encountered an error receiving data: {:?}", e), - } - } - - // Now grab all the packets waiting to be sent and send them - while let Ok(p) = self.packet_receiver.try_recv() { - if let Err(e) = self.send_to(p, time) { - match e { - ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - _ => error!("There was an error sending packet: {:?}", e), - } - } - } - - // Check for idle clients - if let Err(e) = self.handle_idle_clients(time) { - error!("Encountered an error when sending TimeoutEvent: {:?}", e); - } - - // Handle any dead clients - self.handle_dead_clients().expect("Internal laminar error"); - - // Finally send heartbeat packets to connections that require them, if enabled - if let Some(heartbeat_interval) = self.config.heartbeat_interval { - if let Err(e) = self.send_heartbeat_packets(heartbeat_interval, time) { - match e { - ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - _ => error!("There was an error sending a heartbeat packet: {:?}", e), - } - } - } - } - - /// Set the link conditioner for this socket. See [LinkConditioner] for further details. - pub fn set_link_conditioner(&mut self, link_conditioner: Option) { - self.socket_wrapper.set_link_conditioner(link_conditioner); + self.handler.manual_poll(time); } /// Returns the local socket address pub fn local_addr(&self) -> Result { - self.socket_wrapper.local_addr() - } - - /// Iterate through the dead connections and disconnect them by removing them from the - /// connection map while informing the user of this by sending an event. - fn handle_dead_clients(&mut self) -> Result<()> { - let dead_addresses = self.connections.dead_connections(); - for address in dead_addresses { - self.connections.remove_connection(&address); - self.event_sender.send(SocketEvent::Timeout(address))?; - } - - Ok(()) - } - - /// Iterate through all of the idle connections based on `idle_connection_timeout` config and - /// remove them from the active connections. For each connection removed, we will send a - /// `SocketEvent::TimeOut` event to the `event_sender` channel. - fn handle_idle_clients(&mut self, time: Instant) -> Result<()> { - let idle_addresses = self - .connections - .idle_connections(self.config.idle_connection_timeout, time); - for address in idle_addresses { - self.connections.remove_connection(&address); - self.event_sender.send(SocketEvent::Timeout(address))?; - } - - Ok(()) - } - - /// Iterate over all connections which have not sent a packet for a duration of at least - /// `heartbeat_interval` (from config), and send a heartbeat packet to each. - fn send_heartbeat_packets( - &mut self, - heartbeat_interval: Duration, - time: Instant, - ) -> Result { - let heartbeat_packets_and_addrs = self - .connections - .heartbeat_required_connections(heartbeat_interval, time) - .map(|connection| { - ( - connection.process_outgoing(PacketInfo::heartbeat_packet(&[]), None, time), - connection.remote_address, - ) - }) - .collect::>(); - - let mut bytes_sent = 0; - - for (heartbeat_packet, address) in heartbeat_packets_and_addrs { - let packet = heartbeat_packet? - .into_iter() - .next() - .expect("Heartbeat packet must exists"); - bytes_sent += self - .socket_wrapper - .send_packet(&address, &packet.contents())?; - } - - Ok(bytes_sent) - } - - // Serializes and sends a `Packet` on the socket. On success, returns the number of bytes written. - fn send_to(&mut self, packet: Packet, time: Instant) -> Result { - let connection = - self.connections - .get_or_insert_connection(packet.addr(), &self.config, time); - - let mut bytes_sent = 0; - - // TODO maybe dropped packets shouldn't depend on how often a user sends a packet? - let dropped_packets = connection.gather_dropped_packets(); - for dropped in dropped_packets { - let packets = connection.process_outgoing( - PacketInfo { - packet_type: dropped.packet_type, - payload: &dropped.payload, - // Because a delivery guarantee is only sent with reliable packets - delivery: DeliveryGuarantee::Reliable, - // This is stored with the dropped packet because they could be mixed - ordering: dropped.ordering_guarantee, - }, - dropped.item_identifier, - time, - )?; - - for outgoing in packets { - bytes_sent += self - .socket_wrapper - .send_packet(&packet.addr(), &outgoing.contents())?; - } - } - - let packets = connection.process_outgoing( - PacketInfo::user_packet( - packet.payload(), - packet.delivery_guarantee(), - packet.order_guarantee(), - ), - None, - time, - )?; - for outgoing in packets { - bytes_sent += self - .socket_wrapper - .send_packet(&packet.addr(), &outgoing.contents())?; - } - Ok(bytes_sent) - } - - // On success the packet will be sent on the `event_sender` - fn recv_from(&mut self, time: Instant) -> Result { - match self - .socket_wrapper - .socket() - .recv_from(&mut self.recv_buffer) - { - Ok((recv_len, address)) => { - if recv_len == 0 { - return Err(ErrorKind::ReceivedDataToShort); - } - let received_payload = &self.recv_buffer[..recv_len]; - - if !self.connections.exists(&address) { - self.event_sender.send(SocketEvent::Connect(address))?; - } - - let connection = - self.connections - .get_or_create_connection(address, &self.config, time); - - let packets = match connection { - Left(existing) => existing.process_incoming(received_payload, time)?, - Right(mut anonymous) => anonymous.process_incoming(received_payload, time)?, - }; - for incoming in packets { - self.event_sender - .send(SocketEvent::Packet(incoming.0)) - .unwrap(); - } - } - Err(e) => { - if e.kind() != io::ErrorKind::WouldBlock { - error!("Encountered an error receiving data: {:?}", e); - return Err(e.into()); - } else { - return Ok(UdpSocketState::MaybeEmpty); - } - } - } - - if self.config.blocking_mode { - Ok(UdpSocketState::MaybeEmpty) - } else { - Ok(UdpSocketState::MaybeMore) - } - } - - #[cfg(test)] - fn connection_count(&self) -> usize { - self.connections.count() - } - - #[cfg(test)] - fn forget_all_incoming_packets(&mut self) { - std::thread::sleep(std::time::Duration::from_millis(100)); - self.socket_wrapper.socket().set_nonblocking(true).unwrap(); - loop { - match self - .socket_wrapper - .socket() - .recv_from(&mut self.recv_buffer) - { - Ok((recv_len, _address)) => { - if recv_len == 0 { - panic!("Received data too short"); - } - } - Err(e) => { - if e.kind() != io::ErrorKind::WouldBlock { - panic!("Encountered an error receiving data: {:?}", e); - } else { - self.socket_wrapper - .socket() - .set_nonblocking(!self.config.blocking_mode) - .unwrap(); - return; - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use crate::{ - net::constants::{ACKED_PACKET_HEADER, FRAGMENT_HEADER_SIZE, STANDARD_HEADER_SIZE}, - Config, LinkConditioner, Packet, Socket, SocketEvent, - }; - use std::collections::HashSet; - use std::net::{SocketAddr, UdpSocket}; - use std::time::{Duration, Instant}; - - #[test] - fn binding_to_any() { - assert![Socket::bind_any().is_ok()]; - assert![Socket::bind_any_with_config(Config::default()).is_ok()]; - } - - #[test] - fn blocking_sender_and_receiver() { - let cfg = Config::default(); - - let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap(); - let mut server = Socket::bind_any_with_config(Config { - blocking_mode: true, - ..cfg - }) - .unwrap(); - - let server_addr = server.local_addr().unwrap(); - let client_addr = client.local_addr().unwrap(); - - let time = Instant::now(); - - client - .send(Packet::unreliable(server_addr, b"Hello world!".to_vec())) - .unwrap(); - - client.manual_poll(time); - server.manual_poll(time); - - assert_eq![SocketEvent::Connect(client_addr), server.recv().unwrap()]; - if let SocketEvent::Packet(packet) = server.recv().unwrap() { - assert_eq![b"Hello world!", packet.payload()]; - } else { - panic!["Did not receive a packet when it should"]; - } - } - - #[test] - fn using_sender_and_receiver() { - let server_addr = "127.0.0.1:12310".parse::().unwrap(); - let client_addr = "127.0.0.1:12311".parse::().unwrap(); - - let mut server = Socket::bind(server_addr).unwrap(); - let mut client = Socket::bind(client_addr).unwrap(); - - let time = Instant::now(); - - let sender = client.get_packet_sender(); - let receiver = server.get_event_receiver(); - - sender - .send(Packet::reliable_unordered( - server_addr, - b"Hello world!".to_vec(), - )) - .unwrap(); - - client.manual_poll(time); - server.manual_poll(time); - - assert_eq![Ok(SocketEvent::Connect(client_addr)), receiver.recv()]; - if let SocketEvent::Packet(packet) = receiver.recv().unwrap() { - assert_eq![b"Hello world!", packet.payload()]; - } else { - panic!["Did not receive a packet when it should"]; - } - } - - #[test] - fn initial_packet_is_resent() { - let mut server = Socket::bind("127.0.0.1:12335".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12336".parse::().unwrap()).unwrap(); - - let time = Instant::now(); - - // Send a packet that the server ignores/drops - client - .send(Packet::reliable_unordered( - "127.0.0.1:12335".parse::().unwrap(), - b"Do not arrive".to_vec(), - )) - .unwrap(); - client.manual_poll(time); - - // Drop the inbound packet, this simulates a network error - server.forget_all_incoming_packets(); - - // Send a packet that the server receives - for id in 0..u8::max_value() { - client - .send(create_test_packet(id, "127.0.0.1:12335")) - .unwrap(); - - server - .send(create_test_packet(id, "127.0.0.1:12336")) - .unwrap(); - - client.manual_poll(time); - server.manual_poll(time); - - while let Some(SocketEvent::Packet(pkt)) = server.recv() { - if pkt.payload() == b"Do not arrive" { - return; - } - } - while let Some(_) = client.recv() {} - } - - panic!["Did not receive the ignored packet"]; - } - - #[test] - fn receiving_does_not_allow_denial_of_service() { - let mut server = Socket::bind("127.0.0.1:12337".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12338".parse::().unwrap()).unwrap(); - - // Send a bunch of packets to a server - for _ in 0..3 { - client - .send(Packet::unreliable( - "127.0.0.1:12337".parse::().unwrap(), - vec![1, 2, 3, 4, 5, 6, 7, 8, 9], - )) - .unwrap(); - } - - let time = Instant::now(); - - client.manual_poll(time); - server.manual_poll(time); - - for _ in 0..6 { - assert![server.recv().is_some()]; - } - assert![server.recv().is_none()]; - - // The server shall not have any connection in its connection table even though it received - // packets - assert_eq![0, server.connection_count()]; - - server - .send(Packet::unreliable( - "127.0.0.1:12338".parse::().unwrap(), - vec![1], - )) - .unwrap(); - - server.manual_poll(time); - - // The server only adds to its table after having sent explicitly - assert_eq![1, server.connection_count()]; - } - - #[test] - fn initial_sequenced_is_resent() { - let mut server = Socket::bind("127.0.0.1:12329".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12330".parse::().unwrap()).unwrap(); - - let time = Instant::now(); - - // Send a packet that the server ignores/drops - client - .send(Packet::reliable_sequenced( - "127.0.0.1:12329".parse::().unwrap(), - b"Do not arrive".to_vec(), - None, - )) - .unwrap(); - client.manual_poll(time); - - // Drop the inbound packet, this simulates a network error - server.forget_all_incoming_packets(); - - // Send a packet that the server receives - for id in 0..36 { - client - .send(create_sequenced_packet(id, "127.0.0.1:12329")) - .unwrap(); - - server - .send(create_sequenced_packet(id, "127.0.0.1:12330")) - .unwrap(); - - client.manual_poll(time); - server.manual_poll(time); - - while let Some(SocketEvent::Packet(pkt)) = server.recv() { - if pkt.payload() == b"Do not arrive" { - panic!["Sequenced packet arrived while it should not"]; - } - } - while let Some(_) = client.recv() {} - } - } - - #[test] - fn initial_ordered_is_resent() { - let mut server = Socket::bind("127.0.0.1:12333".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12334".parse::().unwrap()).unwrap(); - - let time = Instant::now(); - - // Send a packet that the server ignores/drops - client - .send(Packet::reliable_ordered( - "127.0.0.1:12333".parse::().unwrap(), - b"Do not arrive".to_vec(), - None, - )) - .unwrap(); - client.manual_poll(time); - - // Drop the inbound packet, this simulates a network error - server.forget_all_incoming_packets(); - - // Send a packet that the server receives - for id in 0..35 { - client - .send(create_ordered_packet(id, "127.0.0.1:12333")) - .unwrap(); - - server - .send(create_ordered_packet(id, "127.0.0.1:12334")) - .unwrap(); - - client.manual_poll(time); - server.manual_poll(time); - - while let Some(SocketEvent::Packet(pkt)) = server.recv() { - if pkt.payload() == b"Do not arrive" { - return; - } - } - while let Some(_) = client.recv() {} - } - - panic!["Did not receive the ignored packet"]; - } - - #[test] - fn do_not_duplicate_sequenced_packets_when_received() { - let mut config = Config::default(); - - let mut client = Socket::bind_any_with_config(config.clone()).unwrap(); - config.blocking_mode = true; - let mut server = Socket::bind_any_with_config(config).unwrap(); - - let server_addr = server.local_addr().unwrap(); - let _client_addr = client.local_addr().unwrap(); - - let time = Instant::now(); - - for id in 0..100 { - client - .send(Packet::reliable_sequenced(server_addr, vec![id], None)) - .unwrap(); - client.manual_poll(time); - server.manual_poll(time); - } - - let mut seen = HashSet::new(); - - while let Some(message) = server.recv() { - match message { - SocketEvent::Connect(_) => {} - SocketEvent::Packet(packet) => { - let byte = packet.payload()[0]; - assert![!seen.contains(&byte)]; - seen.insert(byte); - } - SocketEvent::Timeout(_) => { - panic!["This should not happen, as we've not advanced time"]; - } - } - } - - assert_eq![100, seen.len()]; - } - - #[test] - fn more_than_65536_sequenced_packets() { - let mut config = Config::default(); - - let mut client = Socket::bind_any_with_config(config.clone()).unwrap(); - config.blocking_mode = true; - let mut server = Socket::bind_any_with_config(config).unwrap(); - - let server_addr = server.local_addr().unwrap(); - let client_addr = client.local_addr().unwrap(); - - // Acknowledge the client - server - .send(Packet::unreliable(client_addr, vec![0])) - .unwrap(); - - let time = Instant::now(); - - for id in 0..65536 + 100 { - client - .send(Packet::unreliable_sequenced( - server_addr, - id.to_string().as_bytes().to_vec(), - None, - )) - .unwrap(); - client.manual_poll(time); - server.manual_poll(time); - } - - let mut cnt = 0; - while let Some(message) = server.recv() { - match message { - SocketEvent::Connect(_) => {} - SocketEvent::Packet(_) => { - cnt += 1; - } - SocketEvent::Timeout(_) => { - panic!["This should not happen, as we've not advanced time"]; - } - } - } - assert_eq![65536 + 100, cnt]; + Ok(self.handler.socket().local_addr()?) } - #[test] - fn sequenced_packets_pathological_case() { - let mut config = Config::default(); - - config.max_packets_in_flight = 100; - let mut client = Socket::bind_any_with_config(config.clone()).unwrap(); - config.blocking_mode = true; - let server = Socket::bind_any_with_config(config).unwrap(); - - let server_addr = server.local_addr().unwrap(); - - let time = Instant::now(); - - for id in 0..101 { - client - .send(Packet::reliable_sequenced( - server_addr, - id.to_string().as_bytes().to_vec(), - None, - )) - .unwrap(); - client.manual_poll(time); - - while let Some(event) = client.recv() { - match event { - SocketEvent::Timeout(remote_addr) => { - assert_eq![100, id]; - assert_eq![remote_addr, server_addr]; - return; - } - _ => { - panic!["No other event possible"]; - } - } - } - } - - panic!["Should have received a timeout event"]; - } - - #[test] - fn manual_polling_socket() { - let mut server = Socket::bind("127.0.0.1:12339".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12340".parse::().unwrap()).unwrap(); - - for _ in 0..3 { - client - .send(Packet::unreliable( - "127.0.0.1:12339".parse::().unwrap(), - vec![1, 2, 3, 4, 5, 6, 7, 8, 9], - )) - .unwrap(); - } - - let time = Instant::now(); - - client.manual_poll(time); - server.manual_poll(time); - - assert!(server.recv().is_some()); - assert!(server.recv().is_some()); - assert!(server.recv().is_some()); - } - - #[test] - fn can_send_and_receive() { - let mut server = Socket::bind("127.0.0.1:12342".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12341".parse::().unwrap()).unwrap(); - - for _ in 0..3 { - client - .send(Packet::unreliable( - "127.0.0.1:12342".parse::().unwrap(), - vec![1, 2, 3, 4, 5, 6, 7, 8, 9], - )) - .unwrap(); - } - - let now = Instant::now(); - client.manual_poll(now); - server.manual_poll(now); - - assert!(server.recv().is_some()); - assert!(server.recv().is_some()); - assert!(server.recv().is_some()); - } - - #[test] - fn sending_large_unreliable_packet_should_fail() { - let mut server = Socket::bind("127.0.0.1:12370".parse::().unwrap()).unwrap(); - - assert_eq!( - server - .send_to( - Packet::unreliable("127.0.0.1:12360".parse().unwrap(), vec![1; 5000]), - Instant::now(), - ) - .is_err(), - true - ); - } - - #[test] - fn send_returns_right_size() { - let mut server = Socket::bind("127.0.0.1:12371".parse::().unwrap()).unwrap(); - - assert_eq!( - server - .send_to( - Packet::unreliable("127.0.0.1:12361".parse().unwrap(), vec![1; 1024]), - Instant::now(), - ) - .unwrap(), - 1024 + STANDARD_HEADER_SIZE as usize - ); - } - - #[test] - fn fragmentation_send_returns_right_size() { - let mut server = Socket::bind("127.0.0.1:12372".parse::().unwrap()).unwrap(); - - let fragment_packet_size = STANDARD_HEADER_SIZE + FRAGMENT_HEADER_SIZE; - - // the first fragment of an sequence of fragments contains also the acknowledgment header. - assert_eq!( - server - .send_to( - Packet::reliable_unordered("127.0.0.1:12362".parse().unwrap(), vec![1; 4000]), - Instant::now(), - ) - .unwrap(), - 4000 + (fragment_packet_size * 4 + ACKED_PACKET_HEADER) as usize - ); - } - - #[test] - fn connect_event_occurs() { - let mut server = Socket::bind("127.0.0.1:12345".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12344".parse::().unwrap()).unwrap(); - - client - .send(Packet::unreliable( - "127.0.0.1:12345".parse().unwrap(), - vec![0, 1, 2], - )) - .unwrap(); - - let now = Instant::now(); - client.manual_poll(now); - server.manual_poll(now); - - assert_eq!( - server.recv().unwrap(), - SocketEvent::Connect("127.0.0.1:12344".parse().unwrap()) - ); - } - - #[test] - fn disconnect_event_occurs() { - let mut config = Config::default(); - config.idle_connection_timeout = Duration::from_millis(1); - - let server_addr = "127.0.0.1:12347".parse::().unwrap(); - let client_addr = "127.0.0.1:12346".parse::().unwrap(); - - let mut server = Socket::bind_with_config(server_addr, config.clone()).unwrap(); - let mut client = Socket::bind_with_config(client_addr, config.clone()).unwrap(); - - client - .send(Packet::unreliable(server_addr, vec![0, 1, 2])) - .unwrap(); - - let now = Instant::now(); - client.manual_poll(now); - server.manual_poll(now); - - assert_eq!(server.recv().unwrap(), SocketEvent::Connect(client_addr)); - assert_eq!( - server.recv().unwrap(), - SocketEvent::Packet(Packet::unreliable(client_addr, vec![0, 1, 2])) - ); - - // Acknowledge the client - server - .send(Packet::unreliable(client_addr, vec![])) - .unwrap(); - - server.manual_poll(now); - client.manual_poll(now); - - // Make sure the connection was successful on the client side - assert_eq!( - client.recv().unwrap(), - SocketEvent::Packet(Packet::unreliable(server_addr, vec![])) - ); - - // Give just enough time for no timeout events to occur (yet) - server.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1)); - client.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1)); - - assert_eq!(server.recv(), None); - assert_eq!(client.recv(), None); - - // Give enough time for timeouts to be detected - server.manual_poll(now + config.idle_connection_timeout); - client.manual_poll(now + config.idle_connection_timeout); - - assert_eq!(server.recv().unwrap(), SocketEvent::Timeout(client_addr)); - assert_eq!(client.recv().unwrap(), SocketEvent::Timeout(server_addr)); - } - - #[test] - fn heartbeats_work() { - let mut config = Config::default(); - config.idle_connection_timeout = Duration::from_millis(10); - config.heartbeat_interval = Some(Duration::from_millis(4)); - - let server_addr = "127.0.0.1:12351".parse::().unwrap(); - let client_addr = "127.0.0.1:12352".parse::().unwrap(); - - // Start up a server and a client. - let mut server = Socket::bind_with_config(server_addr, config.clone()).unwrap(); - let mut client = Socket::bind_with_config(client_addr, config.clone()).unwrap(); - - // Initiate a connection - client - .send(Packet::unreliable(server_addr, vec![0, 1, 2])) - .unwrap(); - - let now = Instant::now(); - client.manual_poll(now); - server.manual_poll(now); - - // Make sure the connection was successful on the server side - assert_eq!(server.recv().unwrap(), SocketEvent::Connect(client_addr)); - assert_eq!( - server.recv().unwrap(), - SocketEvent::Packet(Packet::unreliable(client_addr, vec![0, 1, 2])) - ); - - // Acknowledge the client - // This way, the server also knows about the connection and sends heartbeats - server - .send(Packet::unreliable(client_addr, vec![])) - .unwrap(); - - server.manual_poll(now); - client.manual_poll(now); - - // Make sure the connection was successful on the client side - assert_eq!( - client.recv().unwrap(), - SocketEvent::Packet(Packet::unreliable(server_addr, vec![])) - ); - - // Give time to send heartbeats - client.manual_poll(now + config.heartbeat_interval.unwrap()); - server.manual_poll(now + config.heartbeat_interval.unwrap()); - - // Give time for timeouts to occur if no heartbeats were sent - client.manual_poll(now + config.idle_connection_timeout); - server.manual_poll(now + config.idle_connection_timeout); - - // Assert that no disconnection events occurred - assert_eq!(client.recv(), None); - assert_eq!(server.recv(), None); - } - - fn create_test_packet(id: u8, addr: &str) -> Packet { - let payload = vec![id]; - Packet::reliable_unordered(addr.parse().unwrap(), payload) - } - - fn create_ordered_packet(id: u8, addr: &str) -> Packet { - let payload = vec![id]; - Packet::reliable_ordered(addr.parse().unwrap(), payload, None) - } - - fn create_sequenced_packet(id: u8, addr: &str) -> Packet { - let payload = vec![id]; - Packet::reliable_sequenced(addr.parse().unwrap(), payload, None) - } - - #[test] - fn multiple_sends_should_start_sending_dropped() { - const LOCAL_ADDR: &str = "127.0.0.1:13000"; - const REMOTE_ADDR: &str = "127.0.0.1:14000"; - - // Start up a server and a client. - let mut server = Socket::bind(REMOTE_ADDR.parse::().unwrap()).unwrap(); - let mut client = Socket::bind(LOCAL_ADDR.parse::().unwrap()).unwrap(); - - let now = Instant::now(); - - // Send enough packets to ensure that we must have dropped packets. - for i in 0..35 { - client.send(create_test_packet(i, REMOTE_ADDR)).unwrap(); - client.manual_poll(now); - } - - let mut events = Vec::new(); - - loop { - server.manual_poll(now); - if let Some(event) = server.recv() { - events.push(event); - } else { - break; - } - } - - // Ensure that we get the correct number of events to the server. - // 35 connect events plus the 35 messages - assert_eq!(events.len(), 70); - - // Finally the server decides to send us a message back. This necessarily will include - // the ack information for 33 of the sent 35 packets. - server.send(create_test_packet(0, LOCAL_ADDR)).unwrap(); - server.manual_poll(now); - - // Loop to ensure that the client gets the server message before moving on. - loop { - client.manual_poll(now); - if client.recv().is_some() { - break; - } - } - - // This next sent message should end up sending the 2 unacked messages plus the new messages - // with payload 35 - events.clear(); - client.send(create_test_packet(35, REMOTE_ADDR)).unwrap(); - client.manual_poll(now); - - loop { - server.manual_poll(now); - if let Some(event) = server.recv() { - events.push(event); - break; - } - } - - let sent_events: Vec = events - .iter() - .flat_map(|e| match e { - SocketEvent::Packet(p) => Some(p.payload()[0]), - _ => None, - }) - .collect(); - assert_eq!(sent_events, vec![35]); - } - - #[quickcheck_macros::quickcheck] - fn do_not_panic_on_arbitrary_packets(bytes: Vec) { - let receiver = "127.0.0.1:12332".parse::().unwrap(); - let sender = "127.0.0.1:12331".parse::().unwrap(); - - let mut server = Socket::bind(receiver).unwrap(); - - let client = UdpSocket::bind(sender).unwrap(); - - client.send_to(&bytes, receiver).unwrap(); - - let time = Instant::now(); - server.manual_poll(time); - } - - #[test] - fn really_bad_network_keeps_chugging_along() { - let server_addr = "127.0.0.1:12320".parse::().unwrap(); - let client_addr = "127.0.0.1:12321".parse::().unwrap(); - - let mut server = Socket::bind(server_addr).unwrap(); - let mut client = Socket::bind(client_addr).unwrap(); - - let time = Instant::now(); - - // We give both the server and the client a really bad bidirectional link - let link_conditioner = { - let mut lc = LinkConditioner::new(); - lc.set_packet_loss(0.9); - Some(lc) - }; - - client.set_link_conditioner(link_conditioner.clone()); - server.set_link_conditioner(link_conditioner); - - let mut set = HashSet::new(); - - // We chat 100 packets between the client and server, which will re-send any non-acked - // packets - let mut send_many_packets = |dummy: Option| { - for id in 0..100 { - client - .send(Packet::reliable_unordered( - server_addr, - vec![dummy.unwrap_or(id)], - )) - .unwrap(); - - server - .send(Packet::reliable_unordered(client_addr, vec![255])) - .unwrap(); - - client.manual_poll(time); - server.manual_poll(time); - - while let Some(_) = client.recv() {} - while let Some(event) = server.recv() { - match event { - SocketEvent::Packet(pkt) => { - set.insert(pkt.payload()[0]); - } - SocketEvent::Timeout(_) => { - panic!["Unable to time out, time has not advanced"] - } - SocketEvent::Connect(_) => {} - } - } - } - - set.len() - }; - - // The first chatting sequence sends packets 0..100 from the client to the server. After - // this we just chat with a value of 255 so we don't accidentally overlap those chatting - // packets with the packets we want to ack. - send_many_packets(None); - send_many_packets(Some(255)); - send_many_packets(Some(255)); - send_many_packets(Some(255)); - - // 101 because we have 0..100 and 255 from the dummies - assert_eq![101, send_many_packets(Some(255))]; - } - - #[test] - fn local_addr() { - let port = 40000; - let socket = - Socket::bind(format!("127.0.0.1:{}", port).parse::().unwrap()).unwrap(); - assert_eq!(port, socket.local_addr().unwrap().port()); - } - - #[test] - fn ordered_16_bit_overflow() { - let mut cfg = Config::default(); - - let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap(); - let client_addr = client.local_addr().unwrap(); - - cfg.blocking_mode = false; - let mut server = Socket::bind_any_with_config(cfg).unwrap(); - let server_addr = server.local_addr().unwrap(); - - let time = Instant::now(); - - let mut last_payload = String::new(); - - for idx in 0..100_000u64 { - client - .send(Packet::reliable_ordered( - server_addr, - idx.to_string().as_bytes().to_vec(), - None, - )) - .unwrap(); - - client.manual_poll(time); - - while let Some(_) = client.recv() {} - server - .send(Packet::reliable_ordered(client_addr, vec![123], None)) - .unwrap(); - server.manual_poll(time); - - while let Some(msg) = server.recv() { - if let SocketEvent::Packet(pkt) = msg { - last_payload = std::str::from_utf8(pkt.payload()).unwrap().to_string(); - } - } - } - - assert_eq!["99999", last_payload]; - } - - #[test] - fn fragmented_ordered_gets_acked() { - let mut cfg = Config::default(); - cfg.fragment_size = 10; - - let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap(); - let client_addr = client.local_addr().unwrap(); - - cfg.blocking_mode = true; - let mut server = Socket::bind_any_with_config(cfg).unwrap(); - let server_addr = server.local_addr().unwrap(); - - let time = Instant::now(); - let dummy = vec![0]; - - // --- - - client - .send(Packet::unreliable(server_addr, dummy.clone())) - .unwrap(); - client.manual_poll(time); - server - .send(Packet::unreliable(client_addr, dummy.clone())) - .unwrap(); - server.manual_poll(time); - - // --- - - let exceeds = b"Fragmented string".to_vec(); - client - .send(Packet::reliable_ordered(server_addr, exceeds, None)) - .unwrap(); - client.manual_poll(time); - - server.manual_poll(time); - server.manual_poll(time); - server - .send(Packet::reliable_ordered(client_addr, dummy.clone(), None)) - .unwrap(); - - client - .send(Packet::unreliable(server_addr, dummy.clone())) - .unwrap(); - client.manual_poll(time); - server.manual_poll(time); - - for _ in 0..4 { - assert![server.recv().is_some()]; - } - assert![server.recv().is_none()]; - - for _ in 0..34 { - client - .send(Packet::reliable_ordered(server_addr, dummy.clone(), None)) - .unwrap(); - client.manual_poll(time); - server - .send(Packet::reliable_ordered(client_addr, dummy.clone(), None)) - .unwrap(); - server.manual_poll(time); - assert![client.recv().is_some()]; - // If the last iteration returns None here, it indicates we just received a re-sent - // fragment, because `manual_poll` only processes a single incoming UDP packet per - // `manual_poll` if and only if the socket is in blocking mode. - // - // If that functionality is changed, we will receive something unexpected here - match server.recv() { - Some(SocketEvent::Packet(pkt)) => { - assert_eq![dummy, pkt.payload()]; - } - _ => { - panic!["Did not receive expected dummy packet"]; - } - } - } + /// Set the link conditioner for this socket. See [LinkConditioner] for further details. + #[cfg(feature = "tester")] + pub fn set_link_conditioner(&mut self, link_conditioner: Option) { + self.handler + .socket_mut() + .set_link_conditioner(link_conditioner); } } diff --git a/src/net/virtual_connection.rs b/src/net/virtual_connection.rs index ceb913c1..8244e031 100644 --- a/src/net/virtual_connection.rs +++ b/src/net/virtual_connection.rs @@ -54,9 +54,8 @@ impl VirtualConnection { } } - /// Determine if this connection should be dropped due to its state - pub fn should_be_dropped(&self) -> bool { - self.acknowledge_handler.packets_in_flight() > self.config.max_packets_in_flight + pub fn packets_in_flight(&self) -> u16 { + self.acknowledge_handler.packets_in_flight() } /// Returns a [Duration] representing the interval since we last heard from the client @@ -939,4 +938,115 @@ mod tests { panic!("Expected not fragmented packet") } } + + #[test] + fn sending_large_unreliable_packet_should_fail() { + let mut connection = create_virtual_connection(); + let buffer = vec![1; 5000]; + + let res = connection.process_outgoing( + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Unreliable, + OrderingGuarantee::None, + ), + None, + Instant::now(), + ); + + assert_eq!(res.is_err(), true); + } + + #[test] + fn send_returns_right_size() { + let mut connection = create_virtual_connection(); + let buffer = vec![1; 1024]; + + let mut packets = connection + .process_outgoing( + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Unreliable, + OrderingGuarantee::None, + ), + None, + Instant::now(), + ) + .unwrap() + .into_iter(); + let packet = packets.next().unwrap(); + + assert_eq!( + packet.contents().len(), + 1024 + constants::STANDARD_HEADER_SIZE as usize + ); + assert_eq!(packets.next().is_none(), true); + } + + #[test] + fn fragmentation_send_returns_right_size() { + let fragment_packet_size = + constants::STANDARD_HEADER_SIZE + constants::FRAGMENT_HEADER_SIZE; + + let mut connection = create_virtual_connection(); + let buffer = vec![1; 4000]; + + let packets = connection + .process_outgoing( + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Reliable, + OrderingGuarantee::None, + ), + None, + Instant::now(), + ) + .unwrap() + .into_iter(); + + // the first fragment of an sequence of fragments contains also the acknowledgment header. + assert_eq!( + packets.fold(0, |acc, p| acc + p.contents().len()), + 4000 + (fragment_packet_size * 4 + constants::ACKED_PACKET_HEADER) as usize + ); + } + + #[test] + fn ordered_16_bit_overflow() { + let mut send_conn = create_virtual_connection(); + let mut recv_conn = create_virtual_connection(); + + let time = Instant::now(); + let mut last_recv_value = 0u32; + for idx in 1..100_000u32 { + let data_to_send = idx.to_ne_bytes(); + let packet_sent = send_conn + .process_outgoing( + PacketInfo::user_packet( + &data_to_send, + DeliveryGuarantee::Reliable, + OrderingGuarantee::None, + ), + None, + time, + ) + .unwrap() + .into_iter() + .next() + .unwrap(); + + let packets = recv_conn + .process_incoming(&packet_sent.contents(), time) + .unwrap(); + + for (packet, _) in packets.into_iter() { + let mut recv_buff = [0; 4]; + recv_buff.copy_from_slice(packet.payload()); + let value = u32::from_ne_bytes(recv_buff); + assert_eq!(value, last_recv_value + 1); + last_recv_value = value; + } + } + assert_eq![last_recv_value, 99_999]; + } } diff --git a/src/packet/outgoing.rs b/src/packet/outgoing.rs index e622d513..bc3d6177 100644 --- a/src/packet/outgoing.rs +++ b/src/packet/outgoing.rs @@ -104,6 +104,7 @@ impl<'p> OutgoingPacketBuilder<'p> { } /// Packet that that contains data which is ready to be sent to a remote endpoint. +#[derive(Debug)] pub struct OutgoingPacket<'p> { header: Vec, payload: &'p [u8], diff --git a/src/packet/process_result.rs b/src/packet/process_result.rs index 1313c465..40bf4ce8 100644 --- a/src/packet/process_result.rs +++ b/src/packet/process_result.rs @@ -5,6 +5,7 @@ use std::collections::VecDeque; /// Struct that implements `Iterator`, and is used to return incoming (from bytes to packets) or outgoing (from packet to bytes) packets. /// It is used as optimization in cases, where most of the time there is only one element to iterate, and we don't want to create a vector for it. +#[derive(Debug)] pub struct ZeroOrMore { data: Either, VecDeque>, } @@ -41,6 +42,7 @@ impl Iterator for ZeroOrMore { } /// Stores packets with headers that will be sent to the network, implements `IntoIterator` for convenience. +#[derive(Debug)] pub struct OutgoingPackets<'a> { data: ZeroOrMore>, } @@ -71,6 +73,7 @@ impl<'a> IntoIterator for OutgoingPackets<'a> { } /// Stores parsed packets with their types, that was received from network, implements `IntoIterator` for convenience. +#[derive(Debug)] pub struct IncomingPackets { data: ZeroOrMore<(Packet, PacketType)>, } diff --git a/src/test_utils/fake_socket.rs b/src/test_utils/fake_socket.rs new file mode 100644 index 00000000..a0ce0b97 --- /dev/null +++ b/src/test_utils/fake_socket.rs @@ -0,0 +1,65 @@ +use crate::net::{ConnectionManager, LinkConditioner, VirtualConnection}; +use crate::test_utils::*; +use crate::{error::Result, Config, Packet, SocketEvent}; +use crossbeam_channel::{Receiver, Sender}; + +use std::{net::SocketAddr, time::Instant}; + +/// Provides a similar to the real a `Socket`, but with emulated socket implementation. +pub struct FakeSocket { + handler: ConnectionManager, +} + +impl FakeSocket { + /// Binds to the socket. + pub fn bind(network: &NetworkEmulator, addr: SocketAddr, config: Config) -> Result { + Ok(Self { + handler: ConnectionManager::new(network.new_socket(addr)?, config), + }) + } + + /// Returns a handle to the packet sender which provides a thread-safe way to enqueue packets + /// to be processed. This should be used when the socket is busy running its polling loop in a + /// separate thread. + pub fn get_packet_sender(&self) -> Sender { + self.handler.event_sender().clone() + } + + /// Returns a handle to the event receiver which provides a thread-safe way to retrieve events + /// from the socket. This should be used when the socket is busy running its polling loop in + /// a separate thread. + pub fn get_event_receiver(&self) -> Receiver { + self.handler.event_receiver().clone() + } + + /// Sends a packet. + pub fn send(&mut self, packet: Packet) -> Result<()> { + // we can savely unwrap, because receiver will always exist + self.handler.event_sender().send(packet).unwrap(); + Ok(()) + } + + /// Receives a packet. + pub fn recv(&mut self) -> Option { + if let Ok(event) = self.handler.event_receiver().try_recv() { + Some(event) + } else { + None + } + } + + /// Processes any inbound/outbound packets and handle idle clients. + pub fn manual_poll(&mut self, time: Instant) { + self.handler.manual_poll(time); + } + + /// Returns a number of active connections. + pub fn connection_count(&self) -> usize { + self.handler.connections_count() + } + + /// Sets the link conditioner for this socket. See [LinkConditioner] for further details. + pub fn set_link_conditioner(&mut self, conditioner: Option) { + self.handler.socket_mut().set_link_conditioner(conditioner); + } +} diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs new file mode 100644 index 00000000..45460f82 --- /dev/null +++ b/src/test_utils/mod.rs @@ -0,0 +1,5 @@ +mod fake_socket; +mod network_emulator; + +pub use fake_socket::FakeSocket; +pub use network_emulator::{EmulatedSocket, NetworkEmulator}; diff --git a/src/test_utils/network_emulator.rs b/src/test_utils/network_emulator.rs new file mode 100644 index 00000000..9d0f00fb --- /dev/null +++ b/src/test_utils/network_emulator.rs @@ -0,0 +1,106 @@ +use crate::net::{DatagramSocket, LinkConditioner}; + +use std::{ + cell::RefCell, + collections::hash_map::Entry, + collections::{HashMap, VecDeque}, + io::Result, + net::SocketAddr, + rc::Rc, +}; + +/// This type allows to share global state between all sockets, created from the same instance of `NetworkEmulator`. +type GlobalBindings = Rc)>>>>; + +/// Enables to create the emulated socket, that share global state stored by this network emulator. +#[derive(Debug, Default)] +pub struct NetworkEmulator { + network: GlobalBindings, +} + +impl NetworkEmulator { + /// Creates an emulated socket by binding to an address. + /// If other socket already was bound to this address, error will be returned instead. + pub fn new_socket(&self, address: SocketAddr) -> Result { + match self.network.borrow_mut().entry(address) { + Entry::Occupied(_) => Err(std::io::Error::new( + std::io::ErrorKind::AddrInUse, + "Cannot bind to address", + )), + Entry::Vacant(entry) => { + entry.insert(Default::default()); + Ok(EmulatedSocket { + network: self.network.clone(), + address, + conditioner: Default::default(), + }) + } + } + } + + /// Clear all packets from a socket that is bound to provided address. + pub fn clear_packets(&self, addr: SocketAddr) { + if let Some(packets) = self.network.borrow_mut().get_mut(&addr) { + packets.clear(); + } + } +} + +/// Implementation of a socket, that is created by `NetworkEmulator`. +#[derive(Debug, Clone)] +pub struct EmulatedSocket { + network: GlobalBindings, + address: SocketAddr, + conditioner: Option, +} + +impl EmulatedSocket { + pub fn set_link_conditioner(&mut self, conditioner: Option) { + self.conditioner = conditioner; + } +} + +impl DatagramSocket for EmulatedSocket { + /// Sends a packet to and address if there is a socket bound to it. Otherwise it will simply be ignored. + fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> Result { + let send = if let Some(ref mut conditioner) = self.conditioner { + conditioner.should_send() + } else { + true + }; + if send { + if let Some(binded) = self.network.borrow_mut().get_mut(addr) { + binded.push_back((self.address, payload.to_vec())); + } + Ok(payload.len()) + } else { + Ok(0) + } + } + + /// Receives a packet from this socket. + fn receive_packet<'a>(&mut self, buffer: &'a mut [u8]) -> Result<(&'a [u8], SocketAddr)> { + if let Some((addr, payload)) = self + .network + .borrow_mut() + .get_mut(&self.address) + .unwrap() + .pop_front() + { + let slice = &mut buffer[..payload.len()]; + slice.copy_from_slice(payload.as_ref()); + Ok((slice, addr)) + } else { + Err(std::io::ErrorKind::WouldBlock.into()) + } + } + + /// Returns the socket address that this socket was created from. + fn local_addr(&self) -> Result { + Ok(self.address) + } + + fn is_blocking_mode(&self) -> bool { + false + } +} diff --git a/tests/basic_socket_test.rs b/tests/basic_socket_test.rs new file mode 100644 index 00000000..0e3de47b --- /dev/null +++ b/tests/basic_socket_test.rs @@ -0,0 +1,146 @@ +#[cfg(feature = "tester")] +use laminar::LinkConditioner; +use laminar::{Config, Packet, Socket, SocketEvent}; + +use std::{collections::HashSet, net::SocketAddr, time::Instant}; + +#[test] +fn binding_to_any() { + // bind to 10 different addresses + let sock_without_config = (0..5).map(|_| Socket::bind_any()); + let sock_with_config = (0..5).map(|_| Socket::bind_any_with_config(Config::default())); + + let valid_socks: Vec<_> = sock_without_config + .chain(sock_with_config) + .filter_map(|sock| sock.ok()) + .collect(); + assert_eq!(valid_socks.len(), 10); + + let unique_addresses: HashSet<_> = valid_socks + .into_iter() + .map(|sock| sock.local_addr().unwrap()) + .collect(); + assert_eq!(unique_addresses.len(), 10); +} + +#[test] +fn blocking_sender_and_receiver() { + let cfg = Config::default(); + + let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap(); + let mut server = Socket::bind_any_with_config(Config { + blocking_mode: true, + ..cfg + }) + .unwrap(); + + let server_addr = server.local_addr().unwrap(); + let client_addr = client.local_addr().unwrap(); + + let time = Instant::now(); + + client + .send(Packet::unreliable(server_addr, b"Hello world!".to_vec())) + .unwrap(); + + client.manual_poll(time); + server.manual_poll(time); + + assert_eq![SocketEvent::Connect(client_addr), server.recv().unwrap()]; + if let SocketEvent::Packet(packet) = server.recv().unwrap() { + assert_eq![b"Hello world!", packet.payload()]; + } else { + panic!["Did not receive a packet when it should"]; + } +} + +#[test] +fn local_addr() { + let port = 40000; + let socket = + Socket::bind(format!("127.0.0.1:{}", port).parse::().unwrap()).unwrap(); + assert_eq!(port, socket.local_addr().unwrap().port()); +} + +#[test] +#[cfg(feature = "tester")] +fn use_link_conditioner() { + let mut client = Socket::bind_any().unwrap(); + let mut server = Socket::bind_any().unwrap(); + + let server_addr = server.local_addr().unwrap(); + + let link_conditioner = { + let mut lc = LinkConditioner::new(); + lc.set_packet_loss(1.0); + Some(lc) + }; + + client.set_link_conditioner(link_conditioner); + client + .send(Packet::unreliable(server_addr, b"Hello world!".to_vec())) + .unwrap(); + + let time = Instant::now(); + client.manual_poll(time); + server.manual_poll(time); + + assert_eq!(server.recv().is_none(), true); +} + +#[test] +#[cfg(feature = "tester")] +fn poll_in_thread() { + use std::thread; + let mut server = Socket::bind_any().unwrap(); + let mut client = Socket::bind_any().unwrap(); + let server_addr = server.local_addr().unwrap(); + + // get sender and receiver from server, and start polling in separate thread + let (sender, receiver) = (server.get_packet_sender(), server.get_event_receiver()); + let _thread = thread::spawn(move || server.start_polling()); + + // server will responde to this + client + .send(Packet::reliable_unordered(server_addr, b"Hello!".to_vec())) + .expect("This should send"); + // this will break the loop + client + .send(Packet::reliable_unordered(server_addr, b"Bye!".to_vec())) + .expect("This should send"); + client.manual_poll(Instant::now()); + + // listen for received server messages, and break when "Bye!" is received. + loop { + if let Ok(event) = receiver.recv() { + if let SocketEvent::Packet(packet) = event { + let msg = packet.payload(); + + if msg == b"Bye!" { + break; + } + + sender + .send(Packet::reliable_unordered( + packet.addr(), + b"Hi, there!".to_vec(), + )) + .expect("This should send"); + } + } + } + // loop until we get response from server. + loop { + client.manual_poll(Instant::now()); + if let Some(packet) = client.recv() { + assert_eq!( + packet, + SocketEvent::Packet(Packet::reliable_unordered( + server_addr, + b"Hi, there!".to_vec() + )) + ); + break; + } + } +} diff --git a/tests/common/client.rs b/tests/common/client.rs index ede21af8..741fe896 100644 --- a/tests/common/client.rs +++ b/tests/common/client.rs @@ -1,5 +1,5 @@ -use laminar::{Config, Packet, Socket}; -use log::{error, info}; +use laminar::{Packet, Socket}; +use log::info; use std::net::SocketAddr; use std::thread::{self, JoinHandle}; use std::time::Duration; @@ -38,7 +38,7 @@ impl Client { for _ in 0..packets_to_send { let packet = create_packet(); - socket.send(packet); + socket.send(packet).unwrap(); socket.manual_poll(Instant::now()); let beginning_park = Instant::now(); diff --git a/tests/common/server.rs b/tests/common/server.rs index 411b3e0d..29fcb826 100644 --- a/tests/common/server.rs +++ b/tests/common/server.rs @@ -1,5 +1,5 @@ use crossbeam_channel::{Receiver, Sender, TryIter}; -use laminar::{Config, Packet, Socket, SocketEvent, ThroughputMonitoring}; +use laminar::{Packet, Socket, SocketEvent, ThroughputMonitoring}; use log::error; use std::net::SocketAddr;