diff --git a/src/lockmap.rs b/src/lockmap.rs
index 63fddc3..49f24c2 100644
--- a/src/lockmap.rs
+++ b/src/lockmap.rs
@@ -4,13 +4,13 @@ use std::borrow::Borrow;
 use std::cell::UnsafeCell;
 use std::collections::BTreeSet;
 use std::hash::Hash;
-use std::sync::atomic::AtomicU32;
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::OnceLock;
 
 /// Internal state for a key-value pair in the `LockMap`.
 ///
-/// This type manages both the stored value and the queue of waiting threads
-/// for per-key synchronization.
+/// This type manages the stored value, the per-key lock, and a reference count
+/// used for both synchronization optimization and memory management.
 struct State<V> {
     refcnt: AtomicU32,
     mutex: Mutex,
@@ -138,9 +138,7 @@ impl<K, V> LockMap<K, V> {
     {
         let ptr: *mut State<V> = self.map.update(key.clone(), |s| match s {
             Some(state) => {
-                state
-                    .refcnt
-                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                state.refcnt.fetch_add(1, Ordering::AcqRel);
                 let ptr = &**state as *const State<V> as *mut State<V>;
                 (UpdateAction::Keep, ptr)
             }
@@ -185,9 +183,7 @@ impl<K, V> LockMap<K, V> {
     {
         let ptr: *mut State<V> = self.map.update_by_ref(key, |s| match s {
             Some(state) => {
-                state
-                    .refcnt
-                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                state.refcnt.fetch_add(1, Ordering::AcqRel);
                 let ptr = &**state as *const State<V> as *mut State<V>;
                 (UpdateAction::Keep, ptr)
             }
@@ -237,15 +233,14 @@ impl<K, V> LockMap<K, V> {
         let mut ptr: *mut State<V> = std::ptr::null_mut();
         let value = self.map.simple_update(key, |s| match s {
             Some(state) => {
-                if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+                // Use Acquire to ensure we see the latest value if refcnt is 0.
+                if state.refcnt.load(Ordering::Acquire) == 0 {
                     // SAFETY: We are inside the map's shard lock, and refcnt is 0,
                     // meaning no other thread can be holding an `Entry` for this key.
                     let value = unsafe { state.value_ref() }.clone();
                     (SimpleAction::Keep, value)
                 } else {
-                    state
-                        .refcnt
-                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                    state.refcnt.fetch_add(1, Ordering::AcqRel);
                     ptr = &**state as *const State<V> as *mut State<V>;
                     (SimpleAction::Keep, None)
                 }
@@ -289,15 +284,14 @@ impl<K, V> LockMap<K, V> {
     {
         let (ptr, value) = self.map.update(key.clone(), move |s| match s {
             Some(state) => {
-                if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+                // Use Acquire to ensure we see the latest value if refcnt is 0.
+                if state.refcnt.load(Ordering::Acquire) == 0 {
                     // SAFETY: We are inside the map's shard lock, and refcnt is 0,
                     // meaning no other thread can be holding an `Entry` for this key.
                     let value = unsafe { state.value_mut() }.replace(value);
                     (UpdateAction::Keep, (std::ptr::null_mut(), value))
                 } else {
-                    state
-                        .refcnt
-                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                    state.refcnt.fetch_add(1, Ordering::AcqRel);
                     let ptr: *mut State<V> = &**state as *const State<V> as *mut State<V>;
                     (UpdateAction::Keep, (ptr, Some(value)))
                 }
@@ -349,15 +343,14 @@ impl<K, V> LockMap<K, V> {
     {
         let (ptr, value) = self.map.update_by_ref(key, move |s| match s {
             Some(state) => {
-                if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+                // Use Acquire to ensure we see the latest value if refcnt is 0.
+                if state.refcnt.load(Ordering::Acquire) == 0 {
                     // SAFETY: We are inside the map's shard lock, and refcnt is 0,
                     // meaning no other thread can be holding an `Entry` for this key.
                     let value = unsafe { state.value_mut() }.replace(value);
                     (UpdateAction::Keep, (std::ptr::null_mut(), value))
                 } else {
-                    state
-                        .refcnt
-                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                    state.refcnt.fetch_add(1, Ordering::AcqRel);
                     let ptr: *mut State<V> = &**state as *const State<V> as *mut State<V>;
                     (UpdateAction::Keep, (ptr, Some(value)))
                 }
@@ -410,14 +403,13 @@ impl<K, V> LockMap<K, V> {
         let mut ptr: *mut State<V> = std::ptr::null_mut();
         let value = self.map.simple_update(key, |s| match s {
             Some(state) => {
-                if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+                // Use Acquire to ensure we see the latest value if refcnt is 0.
+                if state.refcnt.load(Ordering::Acquire) == 0 {
                     // SAFETY: We are inside the map's shard lock, and refcnt is 0,
                     // meaning no other thread can be holding an `Entry` for this key.
                     (SimpleAction::Keep, unsafe { state.value_ref() }.is_some())
                 } else {
-                    state
-                        .refcnt
-                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                    state.refcnt.fetch_add(1, Ordering::AcqRel);
                     ptr = &**state as *const State<V> as *mut State<V>;
                     (SimpleAction::Keep, false)
                 }
@@ -463,15 +455,14 @@ impl<K, V> LockMap<K, V> {
         let mut ptr: *mut State<V> = std::ptr::null_mut();
         let value = self.map.simple_update(key, |s| match s {
             Some(state) => {
-                if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+                // Use Acquire to ensure we see the latest value if refcnt is 0.
+                if state.refcnt.load(Ordering::Acquire) == 0 {
                     // SAFETY: We are inside the map's shard lock, and refcnt is 0,
                     // meaning no other thread can be holding an `Entry` for this key.
                     let value = unsafe { state.value_mut() }.take();
                     (SimpleAction::Remove, value)
                 } else {
-                    state
-                        .refcnt
-                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                    state.refcnt.fetch_add(1, Ordering::AcqRel);
                     ptr = &**state as *const State<V> as *mut State<V>;
                     (SimpleAction::Keep, None)
                 }
@@ -551,29 +542,33 @@ impl<K, V> LockMap<K, V> {
             .collect()
     }
 
-    fn unlock<Q>(&self, key: &Q)
+    /// Attempts to remove an entry from the map if it's no longer needed.
+    ///
+    /// An entry is considered no longer needed if its reference count is 0
+    /// and it contains no value.
+    fn try_remove_entry<Q>(&self, key: &Q)
     where
         K: Borrow<Q>,
        Q: Eq + Hash + ?Sized,
     {
         self.map.simple_update(key, |value| match value {
             Some(state) => {
-                let prev = state
-                    .refcnt
-                    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
-                if prev == 1
-                    && unsafe {
-                        // SAFETY: We are inside the map's shard lock, and refcnt was 1 (now 0),
-                        // meaning no other thread can be holding an `Entry` for this key.
-                        state.value_ref()
-                    }
-                    .is_none()
+                // SAFETY: We are inside the map's shard lock. If `refcnt` is 0 here,
+                // then no `Entry` is currently held for this key, and no other thread
+                // can increment `refcnt` without first acquiring this same shard lock.
+                // Therefore, if the stored value is also `None`, it is safe to remove
+                // the entry from the map.
+                if state.refcnt.load(Ordering::Acquire) == 0
+                    && unsafe { state.value_ref() }.is_none()
                 {
                     (SimpleAction::Remove, ())
                 } else {
                     (SimpleAction::Keep, ())
                 }
             }
-            None => panic!("impossible: unlock a non-existent key!"),
+            // The key might have been removed by another thread (e.g., via `remove`)
+            // between the `refcnt` decrement and this call.
+            None => (SimpleAction::Keep, ()),
         });
     }
@@ -736,7 +731,20 @@ impl<K, V> Drop for EntryByVal<'_, K, V> {
     fn drop(&mut self) {
         // SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it.
         unsafe { (*self.state).mutex.unlock() };
-        self.map.unlock(&self.key);
+
+        // SAFETY: The pointer `self.state` remains valid here because the `EntryByVal`
+        // incremented the `State`'s reference count when it was created. While `self` is
+        // alive in this `drop` call, the reference count is therefore at least 1, and this
+        // `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The
+        // `State` is only deallocated once its reference count reaches zero, which can only
+        // occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access
+        // `refcnt` is safe at this point.
+        let prev = (unsafe { &*self.state })
+            .refcnt
+            .fetch_sub(1, Ordering::AcqRel);
+        if prev == 1 {
+            self.map.try_remove_entry(&self.key);
+        }
     }
 }
 
@@ -880,7 +888,20 @@ impl<K: Borrow<Q>, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_, K, Q, V> {
     fn drop(&mut self) {
         // SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it.
         unsafe { (*self.state).mutex.unlock() };
-        self.map.unlock(self.key);
+
+        // SAFETY: The pointer `self.state` remains valid here because the `EntryByRef`
+        // incremented the `State`'s reference count when it was created. While `self` is
+        // alive in this `drop` call, the reference count is therefore at least 1, and this
+        // `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The
+        // `State` is only deallocated once its reference count reaches zero, which can only
+        // occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access
+        // `refcnt` is safe at this point.
+        let prev = (unsafe { &*self.state })
+            .refcnt
+            .fetch_sub(1, Ordering::AcqRel);
+        if prev == 1 {
+            self.map.try_remove_entry(self.key);
+        }
     }
 }
 
@@ -957,22 +978,6 @@ mod tests {
         }
     }
 
-    #[test]
-    #[should_panic(expected = "impossible: unlock a non-existent key!")]
-    fn test_lockmap_invalid_unlock() {
-        let map = LockMap::<u32, u32>::new();
-        let state = State {
-            refcnt: AtomicU32::new(1),
-            mutex: Mutex::new(),
-            value: UnsafeCell::new(None),
-        };
-        let _ = EntryByVal {
-            map: &map,
-            key: 7268,
-            state: &state as *const State<u32> as *mut State<u32>,
-        };
-    }
-
     #[test]
     fn test_lockmap_same_key_by_value() {
         let lock_map = Arc::new(LockMap::<u32, u32>::with_capacity(256));
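For context, here is a minimal, self-contained sketch of the release protocol this diff moves into the `Drop` impls: the holder decrements `refcnt` first, outside the map lock, and only the thread that observes the count hit zero re-checks under the lock before removing the entry. It swaps the crate's sharded map and raw `State` pointers for a plain `Mutex<HashMap<..>>` with `Arc`-managed slots, so it illustrates only the ordering and removal logic, not the unsafe pointer-lifetime argument; the names `Slot`, `acquire`, and `release` are illustrative, not part of the crate's API.

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};

// Illustrative stand-in for the crate's `State`: a refcount plus an optional value.
struct Slot {
    refcnt: AtomicU32,
    value: Mutex<Option<u32>>,
}

type Map = Mutex<HashMap<String, Arc<Slot>>>;

fn acquire(map: &Map, key: &str) -> Arc<Slot> {
    let mut guard = map.lock().unwrap();
    let slot = guard
        .entry(key.to_string())
        .or_insert_with(|| {
            Arc::new(Slot {
                refcnt: AtomicU32::new(0),
                value: Mutex::new(None),
            })
        })
        .clone();
    // Incrementing under the map lock is what makes the zero-check in
    // `release` sound: nobody can revive the count without taking this lock.
    slot.refcnt.fetch_add(1, Ordering::AcqRel);
    slot
}

fn release(map: &Map, key: &str, slot: &Slot) {
    // Decrement first, outside the map lock, mirroring the new `Drop` impls.
    let prev = slot.refcnt.fetch_sub(1, Ordering::AcqRel);
    if prev == 1 {
        // Possibly the last holder: re-check under the map lock, as
        // `try_remove_entry` does. A missing key is tolerated; another thread
        // may have removed it already, which is exactly why the rewritten
        // `try_remove_entry` keeps the map unchanged instead of panicking.
        let mut guard = map.lock().unwrap();
        let removable = match guard.get(key) {
            Some(s) => s.refcnt.load(Ordering::Acquire) == 0 && s.value.lock().unwrap().is_none(),
            None => false,
        };
        if removable {
            guard.remove(key);
        }
    }
}

fn main() {
    let map = Arc::new(Mutex::new(HashMap::new()));
    let handles: Vec<_> = (0..4u32)
        .map(|i| {
            let map = Arc::clone(&map);
            std::thread::spawn(move || {
                let slot = acquire(&map, "k");
                // Even threads leave a value behind; odd threads clear it.
                *slot.value.lock().unwrap() = if i % 2 == 0 { Some(i) } else { None };
                release(&map, "k", &slot);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    println!("entries left: {}", map.lock().unwrap().len());
}
```

Note that `release` must tolerate a missing key: between its `fetch_sub` and taking the map lock, another holder can acquire the slot, drop it, and remove the entry first. That window is the race the new `None => (SimpleAction::Keep, ())` arm accounts for, and the reason the old `panic!` and its `should_panic` test are gone.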