diff --git a/library/std/src/sync/reentrant_lock.rs b/library/std/src/sync/reentrant_lock.rs index 80b9e0cf15214..3cf2af2801105 100644 --- a/library/std/src/sync/reentrant_lock.rs +++ b/library/std/src/sync/reentrant_lock.rs @@ -1,12 +1,14 @@ #[cfg(all(test, not(target_os = "emscripten")))] mod tests; +use cfg_if::cfg_if; + use crate::cell::UnsafeCell; use crate::fmt; use crate::ops::Deref; use crate::panic::{RefUnwindSafe, UnwindSafe}; -use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed}; use crate::sys::sync as sys; +use crate::thread::{current_id, ThreadId}; /// A re-entrant mutual exclusion lock /// @@ -53,8 +55,8 @@ use crate::sys::sync as sys; // // The 'owner' field tracks which thread has locked the mutex. // -// We use current_thread_unique_ptr() as the thread identifier, -// which is just the address of a thread local variable. +// We use thread::current_id() as the thread identifier, which is just the +// current thread's ThreadId, so it's unique across the process lifetime. // // If `owner` is set to the identifier of the current thread, // we assume the mutex is already locked and instead of locking it again, @@ -72,14 +74,109 @@ use crate::sys::sync as sys; // since we're not dealing with multiple threads. If it's not equal, // synchronization is left to the mutex, making relaxed memory ordering for // the `owner` field fine in all cases. +// +// On systems without 64 bit atomics we also store the address of a TLS variable +// along the 64-bit TID. We then first check that address against the address +// of that variable on the current thread, and only if they compare equal do we +// compare the actual TIDs. Because we only ever read the TID on the same thread +// that it was written on (or a thread sharing the TLS block with that writer thread), +// we don't need to further synchronize the TID accesses, so they can be regular 64-bit +// non-atomic accesses. #[unstable(feature = "reentrant_lock", issue = "121440")] pub struct ReentrantLock { mutex: sys::Mutex, - owner: AtomicUsize, + owner: Tid, lock_count: UnsafeCell, data: T, } +cfg_if!( + if #[cfg(target_has_atomic = "64")] { + use crate::sync::atomic::{AtomicU64, Ordering::Relaxed}; + + struct Tid(AtomicU64); + + impl Tid { + const fn new() -> Self { + Self(AtomicU64::new(0)) + } + + #[inline] + fn contains(&self, owner: ThreadId) -> bool { + owner.as_u64().get() == self.0.load(Relaxed) + } + + #[inline] + // This is just unsafe to match the API of the Tid type below. + unsafe fn set(&self, tid: Option) { + let value = tid.map_or(0, |tid| tid.as_u64().get()); + self.0.store(value, Relaxed); + } + } + } else { + /// Returns the address of a TLS variable. This is guaranteed to + /// be unique across all currently alive threads. + fn tls_addr() -> usize { + thread_local! { static X: u8 = const { 0u8 } }; + + X.with(|p| <*const u8>::addr(p)) + } + + use crate::sync::atomic::{ + AtomicUsize, + Ordering, + }; + + struct Tid { + // When a thread calls `set()`, this value gets updated to + // the address of a thread local on that thread. This is + // used as a first check in `contains()`; if the `tls_addr` + // doesn't match the TLS address of the current thread, then + // the ThreadId also can't match. Only if the TLS addresses do + // match do we read out the actual TID. + // Note also that we can use relaxed atomic operations here, because + // we only ever read from the tid if `tls_addr` matches the current + // TLS address. In that case, either the the tid has been set by + // the current thread, or by a thread that has terminated before + // the current thread was created. In either case, no further + // synchronization is needed (as per ) + tls_addr: AtomicUsize, + tid: UnsafeCell, + } + + unsafe impl Send for Tid {} + unsafe impl Sync for Tid {} + + impl Tid { + const fn new() -> Self { + Self { tls_addr: AtomicUsize::new(0), tid: UnsafeCell::new(0) } + } + + #[inline] + // NOTE: This assumes that `owner` is the ID of the current + // thread, and may spuriously return `false` if that's not the case. + fn contains(&self, owner: ThreadId) -> bool { + // SAFETY: See the comments in the struct definition. + self.tls_addr.load(Ordering::Relaxed) == tls_addr() + && unsafe { *self.tid.get() } == owner.as_u64().get() + } + + #[inline] + // This may only be called by one thread at a time, and can lead to + // race conditions otherwise. + unsafe fn set(&self, tid: Option) { + // It's important that we set `self.tls_addr` to 0 if the tid is + // cleared. Otherwise, there might be race conditions between + // `set()` and `get()`. + let tls_addr = if tid.is_some() { tls_addr() } else { 0 }; + let value = tid.map_or(0, |tid| tid.as_u64().get()); + self.tls_addr.store(tls_addr, Ordering::Relaxed); + unsafe { *self.tid.get() = value }; + } + } + } +); + #[unstable(feature = "reentrant_lock", issue = "121440")] unsafe impl Send for ReentrantLock {} #[unstable(feature = "reentrant_lock", issue = "121440")] @@ -131,7 +228,7 @@ impl ReentrantLock { pub const fn new(t: T) -> ReentrantLock { ReentrantLock { mutex: sys::Mutex::new(), - owner: AtomicUsize::new(0), + owner: Tid::new(), lock_count: UnsafeCell::new(0), data: t, } @@ -181,14 +278,16 @@ impl ReentrantLock { /// assert_eq!(lock.lock().get(), 10); /// ``` pub fn lock(&self) -> ReentrantLockGuard<'_, T> { - let this_thread = current_thread_unique_ptr(); - // Safety: We only touch lock_count when we own the lock. + let this_thread = current_id(); + // Safety: We only touch lock_count when we own the inner mutex. + // Additionally, we only call `self.owner.set()` while holding + // the inner mutex, so no two threads can call it concurrently. unsafe { - if self.owner.load(Relaxed) == this_thread { + if self.owner.contains(this_thread) { self.increment_lock_count().expect("lock count overflow in reentrant mutex"); } else { self.mutex.lock(); - self.owner.store(this_thread, Relaxed); + self.owner.set(Some(this_thread)); debug_assert_eq!(*self.lock_count.get(), 0); *self.lock_count.get() = 1; } @@ -223,14 +322,16 @@ impl ReentrantLock { /// /// This function does not block. pub(crate) fn try_lock(&self) -> Option> { - let this_thread = current_thread_unique_ptr(); - // Safety: We only touch lock_count when we own the lock. + let this_thread = current_id(); + // Safety: We only touch lock_count when we own the inner mutex. + // Additionally, we only call `self.owner.set()` while holding + // the inner mutex, so no two threads can call it concurrently. unsafe { - if self.owner.load(Relaxed) == this_thread { + if self.owner.contains(this_thread) { self.increment_lock_count()?; Some(ReentrantLockGuard { lock: self }) } else if self.mutex.try_lock() { - self.owner.store(this_thread, Relaxed); + self.owner.set(Some(this_thread)); debug_assert_eq!(*self.lock_count.get(), 0); *self.lock_count.get() = 1; Some(ReentrantLockGuard { lock: self }) @@ -303,18 +404,9 @@ impl Drop for ReentrantLockGuard<'_, T> { unsafe { *self.lock.lock_count.get() -= 1; if *self.lock.lock_count.get() == 0 { - self.lock.owner.store(0, Relaxed); + self.lock.owner.set(None); self.lock.mutex.unlock(); } } } } - -/// Get an address that is unique per running thread. -/// -/// This can be used as a non-null usize-sized ID. -pub(crate) fn current_thread_unique_ptr() -> usize { - // Use a non-drop type to make sure it's still available during thread destruction. - thread_local! { static X: u8 = const { 0 } } - X.with(|x| <*const _>::addr(x)) -} diff --git a/library/std/src/thread/mod.rs b/library/std/src/thread/mod.rs index 22215873933d6..d78b896cd78cd 100644 --- a/library/std/src/thread/mod.rs +++ b/library/std/src/thread/mod.rs @@ -159,7 +159,7 @@ mod tests; use crate::any::Any; -use crate::cell::{OnceCell, UnsafeCell}; +use crate::cell::{Cell, OnceCell, UnsafeCell}; use crate::env; use crate::ffi::{CStr, CString}; use crate::fmt; @@ -698,17 +698,22 @@ where } thread_local! { + // Invariant: `CURRENT` and `CURRENT_ID` will always be initialized together. + // If `CURRENT` is initialized, then `CURRENT_ID` will hold the same value + // as `CURRENT.id()`. static CURRENT: OnceCell = const { OnceCell::new() }; + static CURRENT_ID: Cell> = const { Cell::new(None) }; } /// Sets the thread handle for the current thread. /// /// Aborts if the handle has been set already to reduce code size. pub(crate) fn set_current(thread: Thread) { + let tid = thread.id(); // Using `unwrap` here can add ~3kB to the binary size. We have complete // control over where this is called, so just abort if there is a bug. CURRENT.with(|current| match current.set(thread) { - Ok(()) => {} + Ok(()) => CURRENT_ID.set(Some(tid)), Err(_) => rtabort!("thread::set_current should only be called once per thread"), }); } @@ -718,7 +723,28 @@ pub(crate) fn set_current(thread: Thread) { /// In contrast to the public `current` function, this will not panic if called /// from inside a TLS destructor. pub(crate) fn try_current() -> Option { - CURRENT.try_with(|current| current.get_or_init(|| Thread::new_unnamed()).clone()).ok() + CURRENT + .try_with(|current| { + current + .get_or_init(|| { + let thread = Thread::new_unnamed(); + CURRENT_ID.set(Some(thread.id())); + thread + }) + .clone() + }) + .ok() +} + +/// Gets the id of the thread that invokes it. +#[inline] +pub(crate) fn current_id() -> ThreadId { + CURRENT_ID.get().unwrap_or_else(|| { + // If `CURRENT_ID` isn't initialized yet, then `CURRENT` must also not be initialized. + // `current()` will initialize both `CURRENT` and `CURRENT_ID` so subsequent calls to + // `current_id()` will succeed immediately. + current().id() + }) } /// Gets a handle to the thread that invokes it.