diff --git a/src/lockmap.rs b/src/lockmap.rs index 49f24c2..8fa4654 100644 --- a/src/lockmap.rs +++ b/src/lockmap.rs @@ -7,22 +7,78 @@ use std::hash::Hash; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::OnceLock; +/// Internal flags for `State`. +/// +/// This struct wraps an `u32` to store both the reference count and a "has value" flag. +/// The highest bit is used for the flag, and the remaining bits for the reference count. +struct StateFlags(u32); + +impl StateFlags { + const HAS_VALUE_FLAG: u32 = 1 << 31; + const REFCNT_MASK: u32 = !Self::HAS_VALUE_FLAG; + + fn new(refcnt: u32, has_value: bool) -> Self { + let mut val = refcnt & Self::REFCNT_MASK; + if has_value { + val |= Self::HAS_VALUE_FLAG; + } + Self(val) + } + + fn refcnt(&self) -> u32 { + self.0 & Self::REFCNT_MASK + } + + fn has_value(&self) -> bool { + (self.0 & Self::HAS_VALUE_FLAG) != 0 + } + + fn pending_cleanup(&self) -> bool { + self.0 == 0 // equal to `self.refcnt() == 0 && !self.has_value()` + } +} + /// Internal state for a key-value pair in the `LockMap`. /// /// This type manages the stored value, the per-key lock, and a reference count /// used for both synchronization optimization and memory management. struct State { - refcnt: AtomicU32, + flags: AtomicU32, mutex: Mutex, value: UnsafeCell>, } -// SAFETY: `State` is `Sync` if `V` is `Send` because access to the `UnsafeCell>` -// is strictly controlled by the internal `Mutex`. The `refcnt` is an `AtomicU32` which is -// inherently thread-safe. -unsafe impl Sync for State {} - impl State { + fn new(value: Option, refcnt: u32) -> AliasableBox { + AliasableBox::from_unique(Box::new(Self { + flags: AtomicU32::new(StateFlags::new(refcnt, value.is_some()).0), + mutex: Mutex::new(), + value: UnsafeCell::new(value), + })) + } + + fn flags(&self) -> StateFlags { + StateFlags(self.flags.load(Ordering::Acquire)) + } + + fn inc_ref(&self) -> StateFlags { + StateFlags(self.flags.fetch_add(1, Ordering::AcqRel) + 1) + } + + fn dec_ref(&self) -> StateFlags { + StateFlags(self.flags.fetch_sub(1, Ordering::AcqRel) - 1) + } + + fn set_value_state(&self, has_value: bool) { + if has_value { + self.flags + .fetch_or(StateFlags::HAS_VALUE_FLAG, Ordering::Release); + } else { + self.flags + .fetch_and(!StateFlags::HAS_VALUE_FLAG, Ordering::Release); + } + } + /// # Safety /// /// The caller must ensure that the internal `mutex` is locked. @@ -138,16 +194,12 @@ impl LockMap { { let ptr: *mut State = self.map.update(key.clone(), |s| match s { Some(state) => { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); let ptr = &**state as *const State as *mut State; (UpdateAction::Keep, ptr) } None => { - let state = AliasableBox::from_unique(Box::new(State { - refcnt: AtomicU32::new(1), - mutex: Mutex::new(), - value: UnsafeCell::new(None), - })); + let state = State::new(None, 1); let ptr = &*state as *const State as *mut State; (UpdateAction::Replace(state), ptr) } @@ -183,16 +235,12 @@ impl LockMap { { let ptr: *mut State = self.map.update_by_ref(key, |s| match s { Some(state) => { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); let ptr = &**state as *const State as *mut State; (UpdateAction::Keep, ptr) } None => { - let state = AliasableBox::from_unique(Box::new(State { - refcnt: AtomicU32::new(1), - mutex: Mutex::new(), - value: UnsafeCell::new(None), - })); + let state = State::new(None, 1); let ptr = &*state as *const State as *mut State; (UpdateAction::Replace(state), ptr) } @@ -233,14 +281,13 @@ impl LockMap { let mut ptr: *mut State = std::ptr::null_mut(); let value = self.map.simple_update(key, |s| match s { Some(state) => { - // Use Acquire to ensure we see the latest value if refcnt is 0. - if state.refcnt.load(Ordering::Acquire) == 0 { - // SAFETY: We are inside the map's shard lock, and refcnt is 0, + if state.flags().refcnt() == 0 { + // SAFETY: We are inside the map's shard lock, and the reference count is 0, // meaning no other thread can be holding an `Entry` for this key. let value = unsafe { state.value_ref() }.clone(); (SimpleAction::Keep, value) } else { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); ptr = &**state as *const State as *mut State; (SimpleAction::Keep, None) } @@ -284,24 +331,23 @@ impl LockMap { { let (ptr, value) = self.map.update(key.clone(), move |s| match s { Some(state) => { - // Use Acquire to ensure we see the latest value if refcnt is 0. - if state.refcnt.load(Ordering::Acquire) == 0 { - // SAFETY: We are inside the map's shard lock, and refcnt is 0, + let flags = state.flags(); + if flags.refcnt() == 0 { + // SAFETY: We are inside the map's shard lock, and the reference count is 0, // meaning no other thread can be holding an `Entry` for this key. let value = unsafe { state.value_mut() }.replace(value); + if !flags.has_value() { + state.set_value_state(true); + } (UpdateAction::Keep, (std::ptr::null_mut(), value)) } else { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); let ptr: *mut State = &**state as *const State as *mut State; (UpdateAction::Keep, (ptr, Some(value))) } } None => { - let state = AliasableBox::from_unique(Box::new(State { - refcnt: AtomicU32::new(0), - mutex: Mutex::new(), - value: UnsafeCell::new(Some(value)), - })); + let state = State::new(Some(value), 0); (UpdateAction::Replace(state), (std::ptr::null_mut(), None)) } }); @@ -343,24 +389,23 @@ impl LockMap { { let (ptr, value) = self.map.update_by_ref(key, move |s| match s { Some(state) => { - // Use Acquire to ensure we see the latest value if refcnt is 0. - if state.refcnt.load(Ordering::Acquire) == 0 { - // SAFETY: We are inside the map's shard lock, and refcnt is 0, + let flags = state.flags(); + if flags.refcnt() == 0 { + // SAFETY: We are inside the map's shard lock, and the reference count is 0, // meaning no other thread can be holding an `Entry` for this key. let value = unsafe { state.value_mut() }.replace(value); + if !flags.has_value() { + state.set_value_state(true); + } (UpdateAction::Keep, (std::ptr::null_mut(), value)) } else { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); let ptr: *mut State = &**state as *const State as *mut State; (UpdateAction::Keep, (ptr, Some(value))) } } None => { - let state = AliasableBox::from_unique(Box::new(State { - refcnt: AtomicU32::new(0), - mutex: Mutex::new(), - value: UnsafeCell::new(Some(value)), - })); + let state = State::new(Some(value), 0); (UpdateAction::Replace(state), (std::ptr::null_mut(), None)) } }); @@ -403,13 +448,12 @@ impl LockMap { let mut ptr: *mut State = std::ptr::null_mut(); let value = self.map.simple_update(key, |s| match s { Some(state) => { - // Use Acquire to ensure we see the latest value if refcnt is 0. - if state.refcnt.load(Ordering::Acquire) == 0 { - // SAFETY: We are inside the map's shard lock, and refcnt is 0, + if state.flags().refcnt() == 0 { + // SAFETY: We are inside the map's shard lock, and the reference count is 0, // meaning no other thread can be holding an `Entry` for this key. (SimpleAction::Keep, unsafe { state.value_ref() }.is_some()) } else { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); ptr = &**state as *const State as *mut State; (SimpleAction::Keep, false) } @@ -455,14 +499,13 @@ impl LockMap { let mut ptr: *mut State = std::ptr::null_mut(); let value = self.map.simple_update(key, |s| match s { Some(state) => { - // Use Acquire to ensure we see the latest value if refcnt is 0. - if state.refcnt.load(Ordering::Acquire) == 0 { - // SAFETY: We are inside the map's shard lock, and refcnt is 0, + if state.flags().refcnt() == 0 { + // SAFETY: We are inside the map's shard lock, and the reference count is 0, // meaning no other thread can be holding an `Entry` for this key. let value = unsafe { state.value_mut() }.take(); (SimpleAction::Remove, value) } else { - state.refcnt.fetch_add(1, Ordering::AcqRel); + state.inc_ref(); ptr = &**state as *const State as *mut State; (SimpleAction::Keep, None) } @@ -553,21 +596,19 @@ impl LockMap { { self.map.simple_update(key, |value| match value { Some(state) => { - // SAFETY: We are inside the map's shard lock. If `refcnt` is 0 here, + // SAFETY: We are inside the map's shard lock. If the reference count is 0 here, // then no `Entry` is currently held for this key, and no other thread - // can increment `refcnt` without first acquiring this same shard lock. + // can increment the reference count without first acquiring this same shard lock. // Therefore, if the stored value is also `None`, it is safe to remove // the entry from the map. - if state.refcnt.load(Ordering::Acquire) == 0 - && unsafe { state.value_ref() }.is_none() - { + if state.flags().pending_cleanup() { (SimpleAction::Remove, ()) } else { (SimpleAction::Keep, ()) } } // The key might have been removed by another thread (e.g., via `remove`) - // between the `refcnt` decrement and this call. + // between the reference count decrement and this call. None => (SimpleAction::Keep, ()), }); } @@ -640,16 +681,6 @@ pub struct EntryByVal<'a, K: Eq + Hash, V> { state: *mut State, } -// SAFETY: `EntryByVal` is `Send` if `K` and `V` are `Send`. It holds a raw pointer to `State`, -// which is safe to transfer between threads because the entry is locked and the `State` -// itself is `Sync`. -unsafe impl Send for EntryByVal<'_, K, V> {} - -// SAFETY: `EntryByVal` is `Sync` if `K` is `Sync` and `V` is `Sync`. Multiple threads can -// share a reference to `EntryByVal` safely because all access to the underlying value -// is synchronized by the lock held by the entry. -unsafe impl Sync for EntryByVal<'_, K, V> {} - impl EntryByVal<'_, K, V> { /// Returns a reference to the entry's key. /// @@ -714,7 +745,7 @@ impl EntryByVal<'_, K, V> { /// The value that was stored in the entry, or `None` if the entry was vacant. pub fn remove(&mut self) -> Option { // SAFETY: The entry holds the lock on the `State`, so it is safe to access the value. - unsafe { (*self.state).value_mut() }.take() + self.get_mut().take() } } @@ -729,8 +760,11 @@ impl std::fmt::Debug for Ent impl Drop for EntryByVal<'_, K, V> { fn drop(&mut self) { + // Update flags based on current value state + unsafe { &*self.state }.set_value_state(self.get().is_some()); + // SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it. - unsafe { (*self.state).mutex.unlock() }; + unsafe { &*self.state }.mutex.unlock(); // SAFETY: The pointer `self.state` remains valid here because the `EntryByVal` // incremented the `State`'s reference count when it was created. While `self` is @@ -738,11 +772,9 @@ impl Drop for EntryByVal<'_, K, V> { // `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The // `State` is only deallocated once its reference count reaches zero, which can only // occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access - // `refcnt` is safe at this point. - let prev = (unsafe { &*self.state }) - .refcnt - .fetch_sub(1, Ordering::AcqRel); - if prev == 1 { + // the reference count is safe at this point. + let flags = unsafe { (*self.state).dec_ref() }; + if flags.pending_cleanup() { self.map.try_remove_entry(&self.key); } } @@ -780,28 +812,6 @@ pub struct EntryByRef<'a, 'b, K: Eq + Hash + Borrow, Q: Eq + Hash + ?Sized, V state: *mut State, } -// SAFETY: `EntryByRef` is `Send` if `K`, `Q` and `V` are `Send`. It holds a raw pointer to `State`, -// which is safe to transfer between threads because the entry is locked and the `State` -// itself is `Sync`. -unsafe impl Send for EntryByRef<'_, '_, K, Q, V> -where - K: Eq + Hash + Borrow, - Q: Eq + Hash + ?Sized + Sync, - V: Send, -{ -} - -// SAFETY: `EntryByRef` is `Sync` if `K`, `Q` and `V` are `Sync`. Multiple threads can -// share a reference to `EntryByRef` safely because all access to the underlying value -// is synchronized by the lock held by the entry. -unsafe impl Sync for EntryByRef<'_, '_, K, Q, V> -where - K: Eq + Hash + Borrow, - Q: Eq + Hash + ?Sized + Sync, - V: Sync, -{ -} - impl, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q, V> { /// Returns a reference to the entry's key. /// @@ -886,8 +896,11 @@ where impl, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_, '_, K, Q, V> { fn drop(&mut self) { + // Update flags based on current value state + unsafe { &*self.state }.set_value_state(self.get().is_some()); + // SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it. - unsafe { (*self.state).mutex.unlock() }; + unsafe { &*self.state }.mutex.unlock(); // SAFETY: The pointer `self.state` remains valid here because the `EntryByRef` // incremented the `State`'s reference count when it was created. While `self` is @@ -895,11 +908,9 @@ impl, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_, // `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The // `State` is only deallocated once its reference count reaches zero, which can only // occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access - // `refcnt` is safe at this point. - let prev = (unsafe { &*self.state }) - .refcnt - .fetch_sub(1, Ordering::AcqRel); - if prev == 1 { + // the reference count is safe at this point. + let flags = unsafe { (*self.state).dec_ref() }; + if flags.pending_cleanup() { self.map.try_remove_entry(self.key); } }