diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 87ebb58..5df0a3d 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -44,3 +44,17 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} slug: SF-Zhou/lockmap files: target/nextest/default/junit.xml + + miri: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install nightly toolchain with Miri + uses: dtolnay/rust-toolchain@nightly + with: + components: miri + + - name: Run Miri tests + run: cargo miri test lockmap::tests diff --git a/Cargo.toml b/Cargo.toml index 661ea23..a51e899 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ description = "A high-performance, thread-safe HashMap implementation for Rust t license = "MIT OR Apache-2.0" [dependencies] +aliasable = "0.1.3" atomic-wait = "1" foldhash = "0.1.5" diff --git a/README.md b/README.md index 1fe5604..80144ab 100644 --- a/README.md +++ b/README.md @@ -39,28 +39,6 @@ keys.insert("key2".to_string()); let mut locked_entries = map.batch_lock::>(keys); ``` -## FAQ - -### Why is the miri test failing? - -Running `cargo miri test` will report a Stacked Borrows violation. This is **expected and intentional**. - -The crate uses a pattern where we obtain a raw pointer to heap-allocated data before moving the `Box` into the internal map. This is done to atomically insert and obtain access to the state in a single lock-protected operation: - -```rust -let mut state: Box<_> = Box::new(State { /* ... */ }); -let ptr = state.as_mut() as *mut State; -(UpdateAction::Replace(state), ptr) // Move Box, return pointer -``` - -While this violates Miri's experimental [Stacked Borrows](https://github.com/rust-lang/unsafe-code-guidelines/blob/master/wip/stacked-borrows.md) model, it is **safe in practice** because: - -1. Moving a `Box` doesn't relocate the heap data—the pointer remains valid -2. The `refcnt` mechanism guarantees exclusive access to the state -3. 
Extensive concurrent tests validate correctness under heavy contention - -For a detailed explanation, see [#20](https://github.com/SF-Zhou/lockmap/issues/20). - ## License [![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2FSF-Zhou%2Flockmap.svg?type=large)](https://app.fossa.com/projects/git%2Bgithub.com%2FSF-Zhou%2Flockmap?ref=badge_large) diff --git a/src/futex.rs b/src/futex.rs index dd0a6d7..a256207 100644 --- a/src/futex.rs +++ b/src/futex.rs @@ -77,7 +77,7 @@ impl Mutex { /// Acquires the lock, blocking the current thread until it becomes available. /// /// This function will not return until the lock has been acquired. - /// + /// /// # Panics /// /// This function may panic if the current thread already holds the lock. diff --git a/src/lockmap.rs b/src/lockmap.rs index 31b6982..63fddc3 100644 --- a/src/lockmap.rs +++ b/src/lockmap.rs @@ -1,7 +1,10 @@ use crate::{Mutex, ShardsMap, SimpleAction, UpdateAction}; +use aliasable::boxed::AliasableBox; use std::borrow::Borrow; +use std::cell::UnsafeCell; use std::collections::BTreeSet; use std::hash::Hash; +use std::sync::atomic::AtomicU32; use std::sync::OnceLock; /// Internal state for a key-value pair in the `LockMap`. @@ -9,14 +12,36 @@ use std::sync::OnceLock; /// This type manages both the stored value and the queue of waiting threads /// for per-key synchronization. struct State { - refcnt: u32, + refcnt: AtomicU32, mutex: Mutex, - value: Option, + value: UnsafeCell>, +} + +// SAFETY: `State` is `Sync` if `V` is `Send` because access to the `UnsafeCell>` +// is strictly controlled by the internal `Mutex`. The `refcnt` is an `AtomicU32` which is +// inherently thread-safe. +unsafe impl Sync for State {} + +impl State { + /// # Safety + /// + /// The caller must ensure that the internal `mutex` is locked. + unsafe fn value_ref(&self) -> &Option { + &*self.value.get() + } + + /// # Safety + /// + /// The caller must ensure that the internal `mutex` is locked and they have exclusive access. 
+ #[allow(clippy::mut_from_ref)] + unsafe fn value_mut(&self) -> &mut Option { + &mut *self.value.get() + } } /// A thread-safe hashmap that supports locking entries at the key level. pub struct LockMap { - map: ShardsMap>>, + map: ShardsMap>>, } impl Default for LockMap { @@ -113,17 +138,19 @@ impl LockMap { { let ptr: *mut State = self.map.update(key.clone(), |s| match s { Some(state) => { - state.refcnt += 1; - let ptr = state.as_mut() as _; + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let ptr = &**state as *const State as *mut State; (UpdateAction::Keep, ptr) } None => { - let mut state: Box<_> = Box::new(State { - refcnt: 1, + let state = AliasableBox::from_unique(Box::new(State { + refcnt: AtomicU32::new(1), mutex: Mutex::new(), - value: None, - }); - let ptr = state.as_mut() as _; + value: UnsafeCell::new(None), + })); + let ptr = &*state as *const State as *mut State; (UpdateAction::Replace(state), ptr) } }); @@ -158,17 +185,19 @@ impl LockMap { { let ptr: *mut State = self.map.update_by_ref(key, |s| match s { Some(state) => { - state.refcnt += 1; - let ptr = state.as_mut() as _; + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let ptr = &**state as *const State as *mut State; (UpdateAction::Keep, ptr) } None => { - let mut state: Box<_> = Box::new(State { - refcnt: 1, + let state = AliasableBox::from_unique(Box::new(State { + refcnt: AtomicU32::new(1), mutex: Mutex::new(), - value: None, - }); - let ptr = state.as_mut() as _; + value: UnsafeCell::new(None), + })); + let ptr = &*state as *const State as *mut State; (UpdateAction::Replace(state), ptr) } }); @@ -208,12 +237,16 @@ impl LockMap { let mut ptr: *mut State = std::ptr::null_mut(); let value = self.map.simple_update(key, |s| match s { Some(state) => { - if state.refcnt == 0 { - let value = state.value.clone(); + if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 { + // SAFETY: We are inside the map's shard lock, and refcnt is 0, + 
// meaning no other thread can be holding an `Entry` for this key. + let value = unsafe { state.value_ref() }.clone(); (SimpleAction::Keep, value) } else { - state.refcnt += 1; - ptr = state.as_mut(); + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + ptr = &**state as *const State as *mut State; (SimpleAction::Keep, None) } } @@ -256,21 +289,25 @@ impl LockMap { { let (ptr, value) = self.map.update(key.clone(), move |s| match s { Some(state) => { - if state.refcnt == 0 { - let value = state.value.replace(value); + if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 { + // SAFETY: We are inside the map's shard lock, and refcnt is 0, + // meaning no other thread can be holding an `Entry` for this key. + let value = unsafe { state.value_mut() }.replace(value); (UpdateAction::Keep, (std::ptr::null_mut(), value)) } else { - state.refcnt += 1; - let ptr: *mut State = state.as_mut(); + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let ptr: *mut State = &**state as *const State as *mut State; (UpdateAction::Keep, (ptr, Some(value))) } } None => { - let state: Box<_> = Box::new(State { - refcnt: 0, + let state = AliasableBox::from_unique(Box::new(State { + refcnt: AtomicU32::new(0), mutex: Mutex::new(), - value: Some(value), - }); + value: UnsafeCell::new(Some(value)), + })); (UpdateAction::Replace(state), (std::ptr::null_mut(), None)) } }); @@ -312,21 +349,25 @@ impl LockMap { { let (ptr, value) = self.map.update_by_ref(key, move |s| match s { Some(state) => { - if state.refcnt == 0 { - let value = state.value.replace(value); + if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 { + // SAFETY: We are inside the map's shard lock, and refcnt is 0, + // meaning no other thread can be holding an `Entry` for this key. 
+ let value = unsafe { state.value_mut() }.replace(value); (UpdateAction::Keep, (std::ptr::null_mut(), value)) } else { - state.refcnt += 1; - let ptr: *mut State = state.as_mut(); + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let ptr: *mut State = &**state as *const State as *mut State; (UpdateAction::Keep, (ptr, Some(value))) } } None => { - let state: Box<_> = Box::new(State { - refcnt: 0, + let state = AliasableBox::from_unique(Box::new(State { + refcnt: AtomicU32::new(0), mutex: Mutex::new(), - value: Some(value), - }); + value: UnsafeCell::new(Some(value)), + })); (UpdateAction::Replace(state), (std::ptr::null_mut(), None)) } }); @@ -369,11 +410,15 @@ impl LockMap { let mut ptr: *mut State = std::ptr::null_mut(); let value = self.map.simple_update(key, |s| match s { Some(state) => { - if state.refcnt == 0 { - (SimpleAction::Keep, state.value.is_some()) + if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 { + // SAFETY: We are inside the map's shard lock, and refcnt is 0, + // meaning no other thread can be holding an `Entry` for this key. + (SimpleAction::Keep, unsafe { state.value_ref() }.is_some()) } else { - state.refcnt += 1; - ptr = state.as_mut(); + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + ptr = &**state as *const State as *mut State; (SimpleAction::Keep, false) } } @@ -418,12 +463,16 @@ impl LockMap { let mut ptr: *mut State = std::ptr::null_mut(); let value = self.map.simple_update(key, |s| match s { Some(state) => { - if state.refcnt == 0 { - let value = state.value.take(); + if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 { + // SAFETY: We are inside the map's shard lock, and refcnt is 0, + // meaning no other thread can be holding an `Entry` for this key. 
+ let value = unsafe { state.value_mut() }.take(); (SimpleAction::Remove, value) } else { - state.refcnt += 1; - ptr = state.as_mut(); + state + .refcnt + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + ptr = &**state as *const State as *mut State; (SimpleAction::Keep, None) } } @@ -509,8 +558,16 @@ impl LockMap { { self.map.simple_update(key, |value| match value { Some(state) => { - state.refcnt -= 1; - if state.value.is_none() && state.refcnt == 0 { + let prev = state + .refcnt + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + if prev == 1 && unsafe { + // SAFETY: We are inside the map's shard lock, and refcnt was 1 (now 0), + // meaning no other thread can be holding an `Entry` for this key. + state.value_ref() + } + .is_none() + { (SimpleAction::Remove, ()) } else { (SimpleAction::Keep, ()) @@ -521,12 +578,15 @@ impl LockMap { } fn guard_by_val(&self, ptr: *mut State, key: K) -> EntryByVal<'_, K, V> { - let state = unsafe { &mut *ptr }; - state.mutex.lock(); + // SAFETY: The pointer `ptr` is valid because it was just retrieved from the map + // and its reference count was incremented, ensuring it won't be dropped. + // The `AliasableBox` in the map ensures the `State` remains at a stable + // memory location. + unsafe { (*ptr).mutex.lock() }; EntryByVal { map: self, key, - state, + state: ptr, } } @@ -539,12 +599,12 @@ impl LockMap { K: Borrow, Q: Eq + Hash + ?Sized, { - let state = unsafe { &mut *ptr }; - state.mutex.lock(); + // SAFETY: Same as `guard_by_val`. + unsafe { (*ptr).mutex.lock() }; EntryByRef { map: self, key, - state, + state: ptr, } } } @@ -582,9 +642,19 @@ impl std::fmt::Debug for LockMap { pub struct EntryByVal<'a, K: Eq + Hash, V> { map: &'a LockMap, key: K, - state: &'a mut State, + state: *mut State, } +// SAFETY: `EntryByVal` is `Send` if `K` and `V` are `Send`. It holds a raw pointer to `State`, +// which is safe to transfer between threads because the entry is locked and the `State` +// itself is `Sync`. 
+unsafe impl Send for EntryByVal<'_, K, V> {} + +// SAFETY: `EntryByVal` is `Sync` if `K` is `Sync` and `V` is `Sync`. Multiple threads can +// share a reference to `EntryByVal` safely because all access to the underlying value +// is synchronized by the lock held by the entry. +unsafe impl Sync for EntryByVal<'_, K, V> {} + impl EntryByVal<'_, K, V> { /// Returns a reference to the entry's key. /// @@ -601,7 +671,8 @@ impl EntryByVal<'_, K, V> { /// /// A reference to `Some(V)` if the entry has a value, or `None` if the entry is vacant. pub fn get(&self) -> &Option { - &self.state.value + // SAFETY: The entry holds the lock on the `State`, so it is safe to access the value. + unsafe { (*self.state).value_ref() } } /// Returns a mutable reference to the entry's value. @@ -610,7 +681,8 @@ impl EntryByVal<'_, K, V> { /// /// A mutable reference to `Some(V)` if the entry has a value, or `None` if the entry is vacant. pub fn get_mut(&mut self) -> &mut Option { - &mut self.state.value + // SAFETY: The entry holds the lock on the `State`, so it is safe to access the value. + unsafe { (*self.state).value_mut() } } /// Sets the value of the entry, returning the old value if it existed. @@ -623,7 +695,7 @@ impl EntryByVal<'_, K, V> { /// /// The previous value if the entry was occupied, or `None` if it was vacant. pub fn insert(&mut self, value: V) -> Option { - self.state.value.replace(value) + self.get_mut().replace(value) } /// Swaps the value of the entry with the provided value. @@ -636,7 +708,7 @@ impl EntryByVal<'_, K, V> { /// /// The previous value of the entry. pub fn swap(&mut self, mut value: Option) -> Option { - std::mem::swap(&mut self.state.value, &mut value); + std::mem::swap(self.get_mut(), &mut value); value } @@ -646,7 +718,8 @@ impl EntryByVal<'_, K, V> { /// /// The value that was stored in the entry, or `None` if the entry was vacant. 
pub fn remove(&mut self) -> Option { - self.state.value.take() + // SAFETY: The entry holds the lock on the `State`, so it is safe to access the value. + unsafe { (*self.state).value_mut() }.take() } } @@ -654,14 +727,15 @@ impl std::fmt::Debug for Ent fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EntryByVal") .field("key", &self.key) - .field("value", &self.state.value) + .field("value", self.get()) .finish() } } impl Drop for EntryByVal<'_, K, V> { fn drop(&mut self) { - self.state.mutex.unlock(); + // SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it. + unsafe { (*self.state).mutex.unlock() }; self.map.unlock(&self.key); } } @@ -695,7 +769,29 @@ impl Drop for EntryByVal<'_, K, V> { pub struct EntryByRef<'a, 'b, K: Eq + Hash + Borrow, Q: Eq + Hash + ?Sized, V> { map: &'a LockMap, key: &'b Q, - state: &'a mut State, + state: *mut State, +} + +// SAFETY: `EntryByRef` is `Send` when `Q` is `Sync` (the entry stores a `&Q`) and `V` is `Send`, matching the bounds below. It holds a raw pointer to `State`, +// which is safe to transfer between threads because the entry is locked and the `State` +// itself is `Sync`. +unsafe impl Send for EntryByRef<'_, '_, K, Q, V> +where + K: Eq + Hash + Borrow, + Q: Eq + Hash + ?Sized + Sync, + V: Send, +{ +} + +// SAFETY: `EntryByRef` is `Sync` when `Q` and `V` are `Sync`, matching the bounds below. Multiple threads can +// share a reference to `EntryByRef` safely because all access to the underlying value +// is synchronized by the lock held by the entry. +unsafe impl Sync for EntryByRef<'_, '_, K, Q, V> +where + K: Eq + Hash + Borrow, + Q: Eq + Hash + ?Sized + Sync, + V: Sync, +{ } impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q, V> { @@ -714,7 +810,8 @@ impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q /// /// A reference to `Some(V)` if the entry has a value, or `None` if the entry is vacant. pub fn get(&self) -> &Option { - &self.state.value + // SAFETY: The entry holds the lock on the `State`, so it is safe to access the value. 
+ unsafe { (*self.state).value_ref() } } /// Returns a mutable reference to the entry's value. @@ -723,7 +820,8 @@ impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q /// /// A mutable reference to `Some(V)` if the entry has a value, or `None` if the entry is vacant. pub fn get_mut(&mut self) -> &mut Option { - &mut self.state.value + // SAFETY: The entry holds the lock on the `State`, so it is safe to access the value. + unsafe { (*self.state).value_mut() } } /// Sets the value of the entry, returning the old value if it existed. @@ -736,7 +834,7 @@ impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q /// /// The previous value if the entry was occupied, or `None` if it was vacant. pub fn insert(&mut self, value: V) -> Option { - self.state.value.replace(value) + self.get_mut().replace(value) } /// Swaps the value of the entry with the provided value. @@ -749,7 +847,7 @@ impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q /// /// The previous value of the entry. pub fn swap(&mut self, mut value: Option) -> Option { - std::mem::swap(&mut self.state.value, &mut value); + std::mem::swap(self.get_mut(), &mut value); value } @@ -759,7 +857,8 @@ impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q /// /// The value that was stored in the entry, or `None` if the entry was vacant. pub fn remove(&mut self) -> Option { - self.state.value.take() + // No `unsafe` is needed here: `get_mut` performs the access under the lock held by this entry. + self.get_mut().take() } } @@ -772,14 +871,15 @@ where fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EntryByRef") .field("key", &self.key) - .field("value", &self.state.value) + .field("value", self.get()) .finish() } } impl, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_, '_, K, Q, V> { fn drop(&mut self) { - self.state.mutex.unlock(); + // SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it. 
+ unsafe { (*self.state).mutex.unlock() }; self.map.unlock(self.key); } } @@ -861,15 +961,15 @@ mod tests { #[should_panic(expected = "impossible: unlock a non-existent key!")] fn test_lockmap_invalid_unlock() { let map = LockMap::::new(); - let mut state = State { - refcnt: 1, + let state = State { + refcnt: AtomicU32::new(1), mutex: Mutex::new(), - value: None, + value: UnsafeCell::new(None), }; let _ = EntryByVal { map: &map, key: 7268, - state: &mut state, + state: &state as *const State as *mut State, }; } @@ -877,7 +977,10 @@ mod tests { fn test_lockmap_same_key_by_value() { let lock_map = Arc::new(LockMap::::with_capacity(256)); let current = Arc::new(AtomicU32::default()); + #[cfg(not(miri))] const N: usize = 1 << 20; + #[cfg(miri)] + const N: usize = 1 << 6; const M: usize = 4; const S: usize = 0; @@ -911,7 +1014,10 @@ mod tests { fn test_lockmap_same_key_by_ref() { let lock_map = Arc::new(LockMap::::with_capacity(256)); let current = Arc::new(AtomicU32::default()); + #[cfg(not(miri))] const N: usize = 1 << 20; + #[cfg(miri)] + const N: usize = 1 << 6; const M: usize = 4; const S: &str = "hello"; @@ -947,7 +1053,10 @@ mod tests { fn test_lockmap_random_key() { let lock_map = Arc::new(LockMap::::with_capacity_and_shard_amount(256, 16)); let total = Arc::new(AtomicU32::default()); + #[cfg(not(miri))] const N: usize = 1 << 12; + #[cfg(miri)] + const N: usize = 1 << 6; const M: usize = 8; let threads = (0..M) @@ -975,7 +1084,10 @@ mod tests { fn test_lockmap_random_batch_lock() { let lock_map = Arc::new(LockMap::::with_capacity_and_shard_amount(256, 16)); let total = Arc::new(AtomicU32::default()); + #[cfg(not(miri))] const N: usize = 1 << 16; + #[cfg(miri)] + const N: usize = 1 << 6; const M: usize = 8; let threads = (0..M) @@ -986,12 +1098,12 @@ mod tests { for _ in 0..N { let keys = (0..3).map(|_| rand::random::() % 32).collect(); let mut entries: HashMap<_, _> = lock_map.batch_lock(keys); - for (_key, entry) in &mut entries { + for entry in 
entries.values_mut() { assert!(entry.get().is_none()); entry.insert(1); } total.fetch_add(1, Ordering::AcqRel); - for (_key, entry) in &mut entries { + for entry in entries.values_mut() { entry.remove(); } } @@ -1006,7 +1118,10 @@ mod tests { #[test] fn test_lockmap_get_set() { let lock_map = Arc::new(LockMap::::with_capacity_and_shard_amount(256, 16)); + #[cfg(not(miri))] const N: usize = 1 << 20; + #[cfg(miri)] + const N: usize = 1 << 6; let entry_thread = { let lock_map = lock_map.clone(); @@ -1062,7 +1177,10 @@ mod tests { let lock_map = Arc::new(LockMap::::with_capacity_and_shard_amount( 256, 16, )); + #[cfg(not(miri))] const N: usize = 1 << 18; + #[cfg(miri)] + const N: usize = 1 << 6; let entry_thread = { let lock_map = lock_map.clone(); @@ -1116,8 +1234,14 @@ mod tests { #[test] fn test_lockmap_heavy_contention() { let lock_map = Arc::new(LockMap::::new()); + #[cfg(not(miri))] const THREADS: usize = 16; + #[cfg(miri)] + const THREADS: usize = 4; + #[cfg(not(miri))] const OPS_PER_THREAD: usize = 10000; + #[cfg(miri)] + const OPS_PER_THREAD: usize = 10; const HOT_KEYS: u32 = 5; let counter = Arc::new(AtomicU32::new(0)); diff --git a/src/shards_map.rs b/src/shards_map.rs index dd111a3..be4b674 100644 --- a/src/shards_map.rs +++ b/src/shards_map.rs @@ -83,7 +83,7 @@ where /// # Arguments /// /// * `key` - The key to update - /// * `func` - A function that takes an `Option<&mut V>` and returns a tuple containing + /// * `func` - A function that takes an `Option<&mut V>` and returns a tuple containing /// the action to take (`SimpleAction::Keep` or `SimpleAction::Remove`) and a result value /// /// # Returns