src/lockmap.rs: 209 changes (110 additions, 99 deletions)
@@ -7,22 +7,78 @@ use std::hash::Hash;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::OnceLock;

/// Internal flags for `State<V>`.
///
/// This struct wraps a `u32` to store both the reference count and a "has value" flag.
/// The highest bit is used for the flag, and the remaining bits for the reference count.
struct StateFlags(u32);

impl StateFlags {
const HAS_VALUE_FLAG: u32 = 1 << 31;
const REFCNT_MASK: u32 = !Self::HAS_VALUE_FLAG;

fn new(refcnt: u32, has_value: bool) -> Self {
let mut val = refcnt & Self::REFCNT_MASK;
if has_value {
val |= Self::HAS_VALUE_FLAG;
}
Self(val)
}

fn refcnt(&self) -> u32 {
self.0 & Self::REFCNT_MASK
}

fn has_value(&self) -> bool {
(self.0 & Self::HAS_VALUE_FLAG) != 0
}

fn pending_cleanup(&self) -> bool {
self.0 == 0 // equivalent to `self.refcnt() == 0 && !self.has_value()`
}
}
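
To make the packing concrete, here is a minimal, standalone sketch (not part of the patch) that mirrors the constants above: the high bit carries `has_value`, the low 31 bits carry the reference count, and the all-zero word is exactly the `pending_cleanup` state.

```rust
const HAS_VALUE_FLAG: u32 = 1 << 31;
const REFCNT_MASK: u32 = !HAS_VALUE_FLAG;

// Same layout as `StateFlags::new`: count in the low bits, flag on top.
fn pack(refcnt: u32, has_value: bool) -> u32 {
    (refcnt & REFCNT_MASK) | if has_value { HAS_VALUE_FLAG } else { 0 }
}

fn main() {
    let f = pack(3, true);
    assert_eq!(f & REFCNT_MASK, 3);    // the count round-trips
    assert_ne!(f & HAS_VALUE_FLAG, 0); // the flag bit is set
    assert_eq!(pack(0, false), 0);     // the `pending_cleanup` encoding
}
```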

/// Internal state for a key-value pair in the `LockMap`.
///
/// This type manages the stored value, the per-key lock, and a reference count
/// used both to short-circuit synchronization and to manage the state's lifetime.
struct State<V> {
refcnt: AtomicU32,
flags: AtomicU32,
mutex: Mutex,
value: UnsafeCell<Option<V>>,
}

// SAFETY: `State<V>` is `Sync` if `V` is `Send` because access to the `UnsafeCell<Option<V>>`
// is strictly controlled by the internal `Mutex`. The reference count and `has_value` flag
// live in `flags`, an `AtomicU32`, which is inherently thread-safe.
unsafe impl<V: Send> Sync for State<V> {}

impl<V> State<V> {
fn new(value: Option<V>, refcnt: u32) -> AliasableBox<Self> {
AliasableBox::from_unique(Box::new(Self {
flags: AtomicU32::new(StateFlags::new(refcnt, value.is_some()).0),
mutex: Mutex::new(),
value: UnsafeCell::new(value),
}))
}

fn flags(&self) -> StateFlags {
StateFlags(self.flags.load(Ordering::Acquire))
}

fn inc_ref(&self) -> StateFlags {
StateFlags(self.flags.fetch_add(1, Ordering::AcqRel) + 1)
}

fn dec_ref(&self) -> StateFlags {
StateFlags(self.flags.fetch_sub(1, Ordering::AcqRel) - 1)
}

fn set_value_state(&self, has_value: bool) {
if has_value {
self.flags
.fetch_or(StateFlags::HAS_VALUE_FLAG, Ordering::Release);
} else {
self.flags
.fetch_and(!StateFlags::HAS_VALUE_FLAG, Ordering::Release);
}
}
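
The packed representation is what lets `inc_ref` and `dec_ref` be plain `fetch_add`/`fetch_sub` on the whole word: as long as the count never reaches 2^31, an increment cannot carry into the `HAS_VALUE_FLAG` bit, and `set_value_state`'s `fetch_or`/`fetch_and` touch only that bit. A standalone sketch of that behaviour, under the same layout:

```rust
use std::sync::atomic::{AtomicU32, Ordering};

const HAS_VALUE_FLAG: u32 = 1 << 31;

fn main() {
    // has_value = true, refcnt = 0
    let flags = AtomicU32::new(HAS_VALUE_FLAG);

    flags.fetch_add(1, Ordering::AcqRel); // refcnt 0 -> 1; flag bit untouched
    assert_eq!(flags.load(Ordering::Acquire), HAS_VALUE_FLAG | 1);

    flags.fetch_and(!HAS_VALUE_FLAG, Ordering::Release); // clear the flag only
    assert_eq!(flags.load(Ordering::Acquire), 1); // the count survives
}
```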

/// # Safety
///
/// The caller must ensure that the internal `mutex` is locked.
@@ -138,16 +194,12 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let ptr: *mut State<V> = self.map.update(key.clone(), |s| match s {
Some(state) => {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
let ptr = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, ptr)
}
None => {
let state = AliasableBox::from_unique(Box::new(State {
refcnt: AtomicU32::new(1),
mutex: Mutex::new(),
value: UnsafeCell::new(None),
}));
let state = State::new(None, 1);
let ptr = &*state as *const State<V> as *mut State<V>;
(UpdateAction::Replace(state), ptr)
}
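
The protocol in this closure is: under the map's shard lock, either pin the existing `State` by bumping its reference count or install a fresh one with a count of 1; the per-key mutex is then taken outside the shard lock. A hedged usage sketch, assuming a `LockMap::new()` constructor and crate path; `entry` and `get_mut` are inferred from this diff:

```rust
use lockmap::LockMap; // hypothetical crate path

fn main() {
    let map: LockMap<String, u64> = LockMap::new(); // `new` is assumed

    let mut e = map.entry("counter".to_string()); // pins the state, locks the key
    let slot = e.get_mut();                       // &mut Option<u64> under the lock
    let next = (*slot).unwrap_or(0) + 1;          // Option<u64> is Copy, so read it out
    *slot = Some(next);
    drop(e); // unlock, then release the pin; see the `Drop` impls below
}
```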
@@ -183,16 +235,12 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let ptr: *mut State<V> = self.map.update_by_ref(key, |s| match s {
Some(state) => {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
let ptr = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, ptr)
}
None => {
let state = AliasableBox::from_unique(Box::new(State {
refcnt: AtomicU32::new(1),
mutex: Mutex::new(),
value: UnsafeCell::new(None),
}));
let state = State::new(None, 1);
let ptr = &*state as *const State<V> as *mut State<V>;
(UpdateAction::Replace(state), ptr)
}
@@ -233,14 +281,13 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
let mut ptr: *mut State<V> = std::ptr::null_mut();
let value = self.map.simple_update(key, |s| match s {
Some(state) => {
// Use Acquire to ensure we see the latest value if refcnt is 0.
if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
if state.flags().refcnt() == 0 {
// SAFETY: We are inside the map's shard lock, and the reference count is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_ref() }.clone();
(SimpleAction::Keep, value)
} else {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
ptr = &**state as *const State<V> as *mut State<V>;
(SimpleAction::Keep, None)
}
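
This fast path rests on one invariant: while the shard lock is held, a reference count of zero proves that no `Entry` guard exists and none can appear, so the value can be cloned without touching the per-key mutex. A hedged sketch of the two behaviours; `LockMap::new`, `insert`, and `get`'s exact signatures are assumptions inferred from this diff:

```rust
use lockmap::LockMap; // hypothetical crate path

fn main() {
    let map: LockMap<String, i32> = LockMap::new();
    map.insert("k".to_string(), 1);

    // No Entry is pinned (refcnt == 0): the value is cloned under the shard
    // lock alone, never touching the per-key mutex.
    assert_eq!(map.get("k".to_string()), Some(1));

    let _guard = map.entry("k".to_string());
    // While `_guard` lives, a `get` of "k" from another thread would pin the
    // state (`inc_ref`) and wait on the per-key mutex instead.
}
```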
@@ -284,24 +331,23 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let (ptr, value) = self.map.update(key.clone(), move |s| match s {
Some(state) => {
// Use Acquire to ensure we see the latest value if refcnt is 0.
if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
let flags = state.flags();
if flags.refcnt() == 0 {
// SAFETY: We are inside the map's shard lock, and the reference count is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_mut() }.replace(value);
if !flags.has_value() {
state.set_value_state(true);
}
(UpdateAction::Keep, (std::ptr::null_mut(), value))
} else {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
let ptr: *mut State<V> = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, (ptr, Some(value)))
}
}
None => {
let state = AliasableBox::from_unique(Box::new(State {
refcnt: AtomicU32::new(0),
mutex: Mutex::new(),
value: UnsafeCell::new(Some(value)),
}));
let state = State::new(Some(value), 0);
(UpdateAction::Replace(state), (std::ptr::null_mut(), None))
}
});
@@ -343,24 +389,23 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let (ptr, value) = self.map.update_by_ref(key, move |s| match s {
Some(state) => {
// Use Acquire to ensure we see the latest value if refcnt is 0.
if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
let flags = state.flags();
if flags.refcnt() == 0 {
// SAFETY: We are inside the map's shard lock, and the reference count is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_mut() }.replace(value);
if !flags.has_value() {
state.set_value_state(true);
}
(UpdateAction::Keep, (std::ptr::null_mut(), value))
} else {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
let ptr: *mut State<V> = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, (ptr, Some(value)))
}
}
None => {
let state = AliasableBox::from_unique(Box::new(State {
refcnt: AtomicU32::new(0),
mutex: Mutex::new(),
value: UnsafeCell::new(Some(value)),
}));
let state = State::new(Some(value), 0);
(UpdateAction::Replace(state), (std::ptr::null_mut(), None))
}
});
@@ -403,13 +448,12 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
let mut ptr: *mut State<V> = std::ptr::null_mut();
let value = self.map.simple_update(key, |s| match s {
Some(state) => {
// Use Acquire to ensure we see the latest value if refcnt is 0.
if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
if state.flags().refcnt() == 0 {
// SAFETY: We are inside the map's shard lock, and the reference count is 0,
// meaning no other thread can be holding an `Entry` for this key.
(SimpleAction::Keep, unsafe { state.value_ref() }.is_some())
} else {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
ptr = &**state as *const State<V> as *mut State<V>;
(SimpleAction::Keep, false)
}
@@ -455,14 +499,13 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
let mut ptr: *mut State<V> = std::ptr::null_mut();
let value = self.map.simple_update(key, |s| match s {
Some(state) => {
// Use Acquire to ensure we see the latest value if refcnt is 0.
if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
if state.flags().refcnt() == 0 {
// SAFETY: We are inside the map's shard lock, and the reference count is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_mut() }.take();
(SimpleAction::Remove, value)
} else {
state.refcnt.fetch_add(1, Ordering::AcqRel);
state.inc_ref();
ptr = &**state as *const State<V> as *mut State<V>;
(SimpleAction::Keep, None)
}
@@ -553,21 +596,19 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
self.map.simple_update(key, |value| match value {
Some(state) => {
// SAFETY: We are inside the map's shard lock. If `refcnt` is 0 here,
// SAFETY: We are inside the map's shard lock. If the reference count is 0 here,
// then no `Entry` is currently held for this key, and no other thread
// can increment `refcnt` without first acquiring this same shard lock.
// can increment the reference count without first acquiring this same shard lock.
// Therefore, if the stored value is also `None`, it is safe to remove
// the entry from the map.
if state.refcnt.load(Ordering::Acquire) == 0
&& unsafe { state.value_ref() }.is_none()
{
if state.flags().pending_cleanup() {
(SimpleAction::Remove, ())
} else {
(SimpleAction::Keep, ())
}
}
// The key might have been removed by another thread (e.g., via `remove`)
// between the `refcnt` decrement and this call.
// between the reference count decrement and this call.
None => (SimpleAction::Keep, ()),
});
}
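
`try_remove_entry` is the garbage-collection half of the scheme: a slot is removable exactly when its packed word is zero, and the shard lock guarantees no new pin can appear mid-check. Together with the `Drop` impls below, this means emptied slots are reclaimed by the last guard out. A hedged end-to-end sketch; names and signatures are assumptions inferred from this diff:

```rust
use lockmap::LockMap; // hypothetical crate path

fn main() {
    let map: LockMap<String, u32> = LockMap::new(); // `new` is assumed
    map.insert("k".to_string(), 7);

    // Fast path in `remove`: no Entry is pinned, so the value is taken and
    // the slot removed in one step under the shard lock.
    assert_eq!(map.remove("k".to_string()), Some(7));

    // Emptying a held entry instead leaves a refcnt-0, value-less state;
    // `Drop` sees `pending_cleanup` and calls `try_remove_entry` to reap it.
    let mut e = map.entry("k".to_string());
    assert_eq!(e.remove(), None); // the slot was created empty by `entry`
    drop(e); // last pin out, nothing stored: the slot is reclaimed
}
```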
@@ -640,16 +681,6 @@ pub struct EntryByVal<'a, K: Eq + Hash, V> {
state: *mut State<V>,
}

// SAFETY: `EntryByVal` is `Send` if `K` and `V` are `Send`. It holds a raw pointer to `State<V>`,
// which is safe to transfer between threads because the entry is locked and the `State`
// itself is `Sync`.
unsafe impl<K: Eq + Hash + Send, V: Send> Send for EntryByVal<'_, K, V> {}

// SAFETY: `EntryByVal` is `Sync` if `K` is `Sync` and `V` is `Sync`. Multiple threads can
// share a reference to `EntryByVal` safely because all access to the underlying value
// is synchronized by the lock held by the entry.
unsafe impl<K: Eq + Hash + Sync, V: Sync> Sync for EntryByVal<'_, K, V> {}

impl<K: Eq + Hash, V> EntryByVal<'_, K, V> {
/// Returns a reference to the entry's key.
///
@@ -714,7 +745,7 @@ impl<K: Eq + Hash, V> EntryByVal<'_, K, V> {
/// The value that was stored in the entry, or `None` if the entry was vacant.
pub fn remove(&mut self) -> Option<V> {
// SAFETY: The entry holds the lock on the `State`, so it is safe to access the value.
unsafe { (*self.state).value_mut() }.take()
self.get_mut().take()
}
}

@@ -729,20 +760,21 @@ impl<K: Eq + Hash + std::fmt::Debug, V: std::fmt::Debug> std::fmt::Debug for Ent

impl<K: Eq + Hash, V> Drop for EntryByVal<'_, K, V> {
fn drop(&mut self) {
// Update flags based on current value state
unsafe { &*self.state }.set_value_state(self.get().is_some());

// SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it.
unsafe { (*self.state).mutex.unlock() };
unsafe { &*self.state }.mutex.unlock();

// SAFETY: The pointer `self.state` remains valid here because the `EntryByVal`
// incremented the `State`'s reference count when it was created. While `self` is
// alive in this `drop` call, the reference count is therefore at least 1, and this
// `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The
// `State` is only deallocated once its reference count reaches zero, which can only
// occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access
// `refcnt` is safe at this point.
let prev = (unsafe { &*self.state })
.refcnt
.fetch_sub(1, Ordering::AcqRel);
if prev == 1 {
// the reference count is safe at this point.
let flags = unsafe { (*self.state).dec_ref() };
if flags.pending_cleanup() {
self.map.try_remove_entry(&self.key);
}
}
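
The teardown order here is load-bearing: the has-value bit must be published before the mutex is released (so the next locker, or a concurrent cleanup, sees it), and the pin must be dropped last, because the `State` may be freed the moment the count reaches zero. A toy, self-contained analog of that sequence (not the crate's code):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

const HAS_VALUE_FLAG: u32 = 1 << 31;

struct Guard<'a> {
    flags: &'a AtomicU32,
    value_present: bool,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        // (1) publish whether a value remains
        if self.value_present {
            self.flags.fetch_or(HAS_VALUE_FLAG, Ordering::Release);
        } else {
            self.flags.fetch_and(!HAS_VALUE_FLAG, Ordering::Release);
        }
        // (2) a real guard unlocks the per-key mutex here
        // (3) release the pin last; zero means "last out, nothing stored"
        let after = self.flags.fetch_sub(1, Ordering::AcqRel) - 1;
        if after == 0 {
            // a real guard calls `try_remove_entry` on the map here
        }
    }
}

fn main() {
    let flags = AtomicU32::new(1); // one pin, no value stored
    drop(Guard { flags: &flags, value_present: false });
    assert_eq!(flags.load(Ordering::Acquire), 0); // eligible for cleanup
}
```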
@@ -780,28 +812,6 @@ pub struct EntryByRef<'a, 'b, K: Eq + Hash + Borrow<Q>, Q: Eq + Hash + ?Sized, V
state: *mut State<V>,
}

// SAFETY: `EntryByRef` is `Send` if `K`, `Q` and `V` are `Send`. It holds a raw pointer to `State<V>`,
// which is safe to transfer between threads because the entry is locked and the `State`
// itself is `Sync`.
unsafe impl<K, Q, V> Send for EntryByRef<'_, '_, K, Q, V>
where
K: Eq + Hash + Borrow<Q>,
Q: Eq + Hash + ?Sized + Sync,
V: Send,
{
}

// SAFETY: `EntryByRef` is `Sync` if `K`, `Q` and `V` are `Sync`. Multiple threads can
// share a reference to `EntryByRef` safely because all access to the underlying value
// is synchronized by the lock held by the entry.
unsafe impl<K, Q, V> Sync for EntryByRef<'_, '_, K, Q, V>
where
K: Eq + Hash + Borrow<Q>,
Q: Eq + Hash + ?Sized + Sync,
V: Sync,
{
}

impl<K: Eq + Hash + Borrow<Q>, Q: Eq + Hash + ?Sized, V> EntryByRef<'_, '_, K, Q, V> {
/// Returns a reference to the entry's key.
///
@@ -886,20 +896,21 @@ where

impl<K: Eq + Hash + Borrow<Q>, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_, '_, K, Q, V> {
fn drop(&mut self) {
// Update flags based on current value state
unsafe { &*self.state }.set_value_state(self.get().is_some());

// SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it.
unsafe { (*self.state).mutex.unlock() };
unsafe { &*self.state }.mutex.unlock();

// SAFETY: The pointer `self.state` remains valid here because the `EntryByRef`
// incremented the `State`'s reference count when it was created. While `self` is
// alive in this `drop` call, the reference count is therefore at least 1, and this
// `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The
// `State` is only deallocated once its reference count reaches zero, which can only
// occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access
// `refcnt` is safe at this point.
let prev = (unsafe { &*self.state })
.refcnt
.fetch_sub(1, Ordering::AcqRel);
if prev == 1 {
// the reference count is safe at this point.
let flags = unsafe { (*self.state).dec_ref() };
if flags.pending_cleanup() {
self.map.try_remove_entry(self.key);
}
}