diff --git a/Cargo.toml b/Cargo.toml
index 8046937..8bf23da 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,4 +39,4 @@ path = "benches/rtt.rs"
 [[bench]]
 name = "rx"
 harness = false
-path = "benches/rx.rs"
\ No newline at end of file
+path = "benches/rx.rs"
diff --git a/examples/common/mod.rs b/examples/common/mod.rs
index f853591..3e9c36c 100644
--- a/examples/common/mod.rs
+++ b/examples/common/mod.rs
@@ -1,4 +1,5 @@
-use bcast::RingBuffer;
+use anyhow::anyhow;
+use bcast::{error::Error, RingBuffer};
 use rand::{thread_rng, Rng};
 use std::mem::MaybeUninit;
 
@@ -28,7 +29,17 @@ pub fn reader(bytes: &[u8]) -> anyhow::Result<()> {
     let mut count = 0;
     if let Some(batch) = reader.read_batch() {
         for msg in batch {
-            let msg = msg?;
+            let msg = match msg {
+                Ok(msg) => msg,
+                Err(Error::Overrun(position)) => {
+                    println!("overrun for {} bytes, resetting reader", position);
+                    reader.reset();
+                    break;
+                }
+                Err(e) => {
+                    return Err(anyhow!(e));
+                }
+            };
             let mut payload = unsafe { MaybeUninit::new([0u8; 1024]).assume_init() };
             msg.read(&mut payload)?;
             #[cfg(debug_assertions)]
diff --git a/src/lib.rs b/src/lib.rs
index e9695cc..5aa5aa2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -68,6 +68,9 @@ pub const METADATA_BUFFER_SIZE: usize = 1024;
 pub const USER_DEFINED_NULL_VALUE: u32 = 0;
 const FRAME_HEADER_MSG_LEN_MASK: u32 = 0x0FFFFFFF;
 
+// represents the max value we can encode on the frame header for the payload length
+const MAX_PAYLOAD_LEN: usize = (1 << 31) - 1;
+
 /// Ring buffer header that contains producer position. The position is expressed in bytes and
 /// will always increment.
 #[derive(Debug)]
@@ -236,9 +239,6 @@ impl RingBuffer {
         assert!(bytes.len() > size_of::<Header>(), "insufficient size for the header");
         assert!((bytes.len() - size_of::<Header>()).is_power_of_two(), "buffer len must be power of two");
 
-        // represents the max value we can encode on the frame header for the payload length
-        const MAX_PAYLOAD_LEN: usize = (1 << 31) - 1;
-
         let header = bytes.as_ptr() as *mut Header;
         let capacity = bytes.len() - size_of::<Header>();
 
         Self {
@@ -283,6 +283,17 @@ impl RingBuffer {
         Writer { ring: self, position }
     }
 
+    /// Will consume `self` and return instance of writer backed by this ring buffer, with the
+    /// initial position set to the provided value.
+    #[cfg(test)]
+    pub fn into_writer_at(self, position: usize) -> Writer {
+        assert!(get_aligned_size(position) == position, "position must be aligned");
+        self.header().producer_position.store(position, Ordering::SeqCst);
+        // mark as initialised after setting the position
+        self.header().ready.store(true, Ordering::SeqCst);
+        Writer { ring: self, position }
+    }
+
     /// Will consume `self` and return instance of writer backed by this ring buffer. This
     /// method also accepts closure to populate `metadata` buffer.
     pub fn into_writer_with_metadata<F>(mut self, metadata: F) -> Writer {
@@ -431,7 +442,7 @@ impl<'a> Claim<'a> {
             let header = writer.frame_header_mut();
             header.fields = fields;
             header.user_defined = USER_DEFINED_NULL_VALUE;
-            writer.position += padding_len + size_of::<FrameHeader>();
+            writer.position = writer.position.wrapping_add(padding_len + size_of::<FrameHeader>());
         }
 
         let position_snapshot = writer.position;
@@ -490,7 +501,7 @@ impl<'a> Claim<'a> {
         header.user_defined = self.user_defined;
 
         // advance writer position
-        self.writer.position += self.len + size_of::<FrameHeader>();
+        self.writer.position = self.writer.position.wrapping_add(self.len + size_of::<FrameHeader>());
 
         // signal updated producer position
         self.writer
@@ -523,6 +534,7 @@ impl Reader {
 
     /// Set reader initial position (the default is producer current position).
     pub const fn with_initial_position(self, position: usize) -> Self {
+        assert!(get_aligned_size(position) == position, "position must be aligned");
         Self {
             ring: self.ring,
             position: Cell::new(position),
@@ -541,13 +553,20 @@ impl Reader {
         self.position.get() & (self.ring.capacity - 1)
     }
 
+    /// Reset reader position to current producer position, for recovering from overrun.
+    #[cold]
+    pub fn reset(&self) {
+        self.position
+            .set(self.ring.header().producer_position.load(Ordering::Acquire));
+    }
+
     /// Construct `Batch` object that can efficiently read multiple messages in a batch between
     /// `Reader` current position and prevailing producer position. Returns `None` if there is
     /// no new data to read.
     #[inline]
     pub fn read_batch(&self) -> Option<Batch> {
         let producer_position = self.ring.header().producer_position.load(Ordering::Acquire);
-        let limit = producer_position - self.position.get();
+        let limit = producer_position.wrapping_sub(self.position.get());
         if limit == 0 {
             return None;
         }
@@ -564,15 +583,15 @@ pub fn receive_next(&self) -> Option<Result<Message, Error>> {
         let producer_position_before = self.ring.header().producer_position.load(Ordering::Acquire);
         // no new messages
-        if producer_position_before - self.position.get() == 0 {
+        if producer_position_before.wrapping_sub(self.position.get()) == 0 {
             return None;
         }
 
         // attempt to receive next frame
         // if the frame is padding will skip it and attempt to return next frame
-        match self.receive_next_impl(producer_position_before) {
+        match self.receive_next_impl(self.position.get()) {
             Some(msg) => match msg {
                 Ok(msg) if !msg.is_padding => Some(Ok(msg)),
-                Ok(_) => self.receive_next_impl(producer_position_before),
+                Ok(_) => self.receive_next_impl(self.position.get()),
                 Err(err) => Some(Err(err)),
             },
             None => None,
@@ -580,22 +599,17 @@
     }
 
     #[inline]
-    fn receive_next_impl(&self, producer_position_before: usize) -> Option<Result<Message, Error>> {
+    fn receive_next_impl(&self, reader_position: usize) -> Option<Result<Message, Error>> {
         // extract frame header fields
         let frame_header = self.as_frame_header();
         let (is_fin, is_continuation, is_padding, length) = frame_header.unpack_fields();
         let user_defined = frame_header.user_defined;
         let producer_position_after = self.ring.header().producer_position.load(Ordering::Acquire);
 
-        // ensure we have not been overrun
-        if producer_position_after > producer_position_before + self.ring.capacity {
-            return Some(Err(Error::overrun(self.position.get())));
-        }
-
         // construct the massage
         let message = Message {
             header: self.ring.header(),
-            position: self.position.get() + size_of::<FrameHeader>(),
+            position: self.position.get().wrapping_add(size_of::<FrameHeader>()),
             payload_len: length as usize,
             capacity: self.ring.capacity,
             is_fin,
@@ -604,12 +618,17 @@
             user_defined,
         };
 
+        // ensure we have not been overrun by the writer
+        // so the frame header is not overwritten and can be trusted
+        if producer_position_after.wrapping_sub(reader_position) > self.ring.capacity {
+            return Some(Err(Error::overrun(reader_position)));
+        }
+
         // update reader position
         let aligned_payload_len = get_aligned_size(message.payload_len);
         let position = self.position.get();
         self.position
-            .set(position + aligned_payload_len + size_of::<FrameHeader>());
-
+            .set(position.wrapping_add(aligned_payload_len + size_of::<FrameHeader>()));
         Some(Ok(message))
     }
 }
@@ -658,6 +677,11 @@ impl Message {
     /// ```
     #[inline]
     pub fn read(&self, buf: &mut [u8]) -> Result<usize, Error> {
+        assert!(
+            self.payload_len <= min(self.capacity / 2 - size_of::<FrameHeader>(), MAX_PAYLOAD_LEN),
+            "payload size is greater than mtu"
+        );
+        assert!(self.index() + self.payload_len <= self.capacity, "payload overshoots ring buffer");
         // ensure destination buffer is of sufficient size
         if self.payload_len > buf.len() {
             return Err(Error::insufficient_buffer_size(buf.len(), self.payload_len));
@@ -671,7 +695,7 @@ impl Message {
         let producer_position_after = self.header.producer_position.load(Ordering::Acquire);
 
         // ensure we have not been overrun by the producer
-        if producer_position_after > producer_position_before + self.capacity {
+        if producer_position_after.wrapping_sub(producer_position_before) > self.capacity {
             return Err(Error::overrun(self.position));
         }
 
@@ -751,6 +775,7 @@ impl Iterator for BatchIter<'_> {
 mod tests {
     use super::*;
     use crate::error::Error;
+    use rand::{thread_rng, Rng};
     use std::ptr::addr_of;
     use std::sync::atomic::Ordering::SeqCst;
 
@@ -909,6 +934,21 @@ mod tests {
 
         writer.claim(16, true).commit();
         writer.claim(16, true).commit();
+        let msg = reader.receive_next().unwrap();
+        assert!(matches!(msg.unwrap_err(), Error::Overrun(_)));
+    }
+
+    #[test]
+    fn should_overrun_read_batch() {
+        let bytes = [0u8; HEADER_SIZE + 64];
+        let mut writer = RingBuffer::new(&bytes).into_writer();
+        let reader = RingBuffer::new(&bytes).into_reader();
+
+        writer.claim(16, true).commit();
+        writer.claim(16, true).commit();
+        writer.claim(16, true).commit();
+        writer.claim(16, true).commit();
+
         let mut iter = reader.read_batch().unwrap().into_iter();
         let msg = iter.next().unwrap();
         assert!(matches!(msg.unwrap_err(), Error::Overrun(_)));
@@ -1026,8 +1066,9 @@ mod tests {
 
         let data_addr = writer.ring.header().data_ptr() as usize;
         assert_eq!(METADATA_BUFFER_SIZE, data_addr - metadata_addr);
-        assert_eq!(128, metadata_addr - ready_addr);
-        assert_eq!(128, ready_addr - producer_position_addr);
+        // 128 for x86_64, 64 for x86
+        assert_eq!(align_of::<CachePadded<AtomicBool>>(), metadata_addr - ready_addr);
+        assert_eq!(align_of::<CachePadded<AtomicUsize>>(), ready_addr - producer_position_addr);
 
         let header_ptr = writer.ring.header() as *const Header;
         let data_ptr = writer.ring.header().data_ptr();
@@ -1039,7 +1080,10 @@
         assert_eq!(size_of::<Header>() + size_of::<FrameHeader>(), buf_ptr_0 as usize - header_ptr as usize);
         assert_eq!(size_of::<FrameHeader>(), buf_ptr_0 as usize - data_ptr as usize);
         assert_eq!(16, claim.get_buffer().len());
+        #[cfg(target_arch = "x86_64")]
         assert_eq!(1280, size_of::<Header>());
+        #[cfg(target_arch = "x86")]
+        assert_eq!(1152, size_of::<Header>());
         assert_eq!(8, size_of::<FrameHeader>());
         assert_eq!(8, align_of::<FrameHeader>());
         assert_eq!(size_of::<Header>(), frame_ptr_0 as usize - bytes.as_ptr() as usize);
@@ -1249,4 +1293,62 @@ mod tests {
         assert_eq!(104, reader.receive_next().unwrap().unwrap().user_defined);
         assert_eq!(105, reader.receive_next().unwrap().unwrap().user_defined);
     }
+
+    #[test]
+    fn should_position_wrap_around() {
+        let bytes = [0u8; HEADER_SIZE + 2048];
+        let mut writer = RingBuffer::new(&bytes).into_writer_at(usize::MAX - 1023);
+        // last claim before wrap around
+        writer.claim_with_user_defined(1000, true, 100).commit();
+        assert_eq!(usize::MAX - 15, writer.position);
+        // first claim after wrap around, will insert padding frame and
+        // continue from position zero
+        writer.claim_with_user_defined(128, true, 101).commit();
+        assert_eq!(136, writer.position);
+        // a normal claim after wrap around
+        writer.claim_with_user_defined(16, true, 102).commit();
+        assert_eq!(160, writer.position);
+        // verify we got all the messages
+        let reader = RingBuffer::new(&bytes)
+            .into_reader()
+            .with_initial_position(usize::MAX - 1023);
+        assert_eq!(100, reader.receive_next().unwrap().unwrap().user_defined);
+        assert_eq!(101, reader.receive_next().unwrap().unwrap().user_defined);
+        assert_eq!(102, reader.receive_next().unwrap().unwrap().user_defined);
+        // and are still in sync
+        assert_eq!(160, reader.position.get());
+    }
+
+    #[test]
+    fn should_position_wrap_around_and_overrun_reader() {
+        let bytes = [0u8; HEADER_SIZE + 2048];
+        let mut writer = RingBuffer::new(&bytes).into_writer_at(usize::MAX - 2047);
+        let reader = RingBuffer::new(&bytes).into_reader();
+
+        // First claim and read
+        writer.claim_with_user_defined(1000, true, 100).commit();
+        assert_eq!(100, reader.receive_next().unwrap().unwrap().user_defined);
+
+        // Last claim before wrap around
+        writer.claim_with_user_defined(1000, true, 101).commit();
+
+        // First claim after wrap around
+        writer.claim_with_user_defined(512, true, 102).commit();
+
+        // Overrun the reader and overwrite the frame header the reader will read
+        let mut claim = writer.claim_with_user_defined(1000, true, 103);
+        thread_rng().fill(claim.get_buffer_mut());
+        claim.commit();
+
+        assert!(matches!(reader.receive_next().unwrap().unwrap_err(), Error::Overrun(_)));
+        // Reset the reader and start over
+        reader.reset();
+        assert!(reader.receive_next().is_none());
+        // Continue writing and reading
+        assert_eq!(reader.position.get(), writer.position);
+
+        writer.claim_with_user_defined(1000, true, 104).commit();
+
+        assert_eq!(104, reader.receive_next().unwrap().unwrap().user_defined);
+    }
 }
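
Aside (not part of the patch): a minimal standalone sketch of why the overrun checks above switch to wrapping arithmetic. The helper name is_overrun and the literal numbers are illustrative only; the point, as in the diff, is that producer and reader positions are free-running byte counters that may wrap past usize::MAX, so the overrun test must compare the wrapping distance between them against the ring capacity rather than compare absolute positions.

// Hypothetical helper mirroring the check in receive_next_impl / Message::read:
// the distance the producer is ahead of the reader, modulo usize wrap-around.
fn is_overrun(producer_position: usize, reader_position: usize, capacity: usize) -> bool {
    producer_position.wrapping_sub(reader_position) > capacity
}

fn main() {
    let capacity = 2048;

    // producer slightly ahead, no wrap-around: not overrun
    assert!(!is_overrun(1024, 0, capacity));
    // producer lapped the reader by more than one capacity: overrun
    assert!(is_overrun(4096, 0, capacity));

    // positions near usize::MAX, as in should_position_wrap_around:
    // the old form `producer > reader + capacity` would overflow `reader + capacity`,
    // while the wrapping distance stays correct across the wrap
    let reader = usize::MAX - 1023;
    assert!(!is_overrun(reader.wrapping_add(1024), reader, capacity)); // only 1024 bytes ahead
    assert!(is_overrun(reader.wrapping_add(4096), reader, capacity)); // lapped after wrapping
}

Once an overrun is detected, recovery is just Reader::reset(), which snaps the reader back to the current producer position, as the examples/common/mod.rs hunk shows.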