From a46e1f89b3c9e08adba513da162f23383b74cccd Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Wed, 18 Jan 2023 16:34:32 -0800 Subject: [PATCH 1/2] Use pointers instead of `&self` in `Latch::set` `Latch::set` can invalidate its own `&self`, because it releases the owning thread to continue execution, which may then invalidate the latch by deallocation, reuse, etc. We've known about this problem when it comes to accessing latch fields too late, but the possibly dangling reference was still a problem, like rust-lang/rust#55005. The result of that was rust-lang/rust#98017, omitting the LLVM attribute `dereferenceable` on references to `!Freeze` types -- those containing `UnsafeCell`. However, miri's Stacked Borrows implementation is finer- grained than that, only relaxing for the cell itself in the `!Freeze` type. For rayon, that solves the dangling reference in atomic calls, but remains a problem for other fields of a `Latch`. This easiest fix for rayon is to use a raw pointer instead of `&self`. We still end up with some temporary references for stuff like atomics, but those should be fine with the rules above. --- rayon-core/src/job.rs | 2 +- rayon-core/src/latch.rs | 57 +++++++++++++++++++++---------------- rayon-core/src/registry.rs | 6 ++-- rayon-core/src/scope/mod.rs | 46 +++++++++++++++--------------- 4 files changed, 60 insertions(+), 51 deletions(-) diff --git a/rayon-core/src/job.rs b/rayon-core/src/job.rs index b7a3dae18..deccebc1e 100644 --- a/rayon-core/src/job.rs +++ b/rayon-core/src/job.rs @@ -112,7 +112,7 @@ where let abort = unwind::AbortIfPanic; let func = (*this.func.get()).take().unwrap(); (*this.result.get()) = JobResult::call(func); - this.latch.set(); + Latch::set(&this.latch); mem::forget(abort); } } diff --git a/rayon-core/src/latch.rs b/rayon-core/src/latch.rs index 090929374..b1be1e0cf 100644 --- a/rayon-core/src/latch.rs +++ b/rayon-core/src/latch.rs @@ -37,10 +37,15 @@ pub(super) trait Latch { /// /// Setting a latch triggers other threads to wake up and (in some /// cases) complete. This may, in turn, cause memory to be - /// allocated and so forth. One must be very careful about this, + /// deallocated and so forth. One must be very careful about this, /// and it's typically better to read all the fields you will need /// to access *before* a latch is set! - fn set(&self); + /// + /// This function operates on `*const Self` instead of `&self` to allow it + /// to become dangling during this call. The caller must ensure that the + /// pointer is valid upon entry, and not invalidated during the call by any + /// actions other than `set` itself. + unsafe fn set(this: *const Self); } pub(super) trait AsCoreLatch { @@ -123,8 +128,8 @@ impl CoreLatch { /// doing some wakeups; those are encapsulated in the surrounding /// latch code. #[inline] - fn set(&self) -> bool { - let old_state = self.state.swap(SET, Ordering::AcqRel); + unsafe fn set(this: *const Self) -> bool { + let old_state = (*this).state.swap(SET, Ordering::AcqRel); old_state == SLEEPING } @@ -186,16 +191,16 @@ impl<'r> AsCoreLatch for SpinLatch<'r> { impl<'r> Latch for SpinLatch<'r> { #[inline] - fn set(&self) { + unsafe fn set(this: *const Self) { let cross_registry; - let registry: &Registry = if self.cross { + let registry: &Registry = if (*this).cross { // Ensure the registry stays alive while we notify it. 
// Otherwise, it would be possible that we set the spin // latch and the other thread sees it and exits, causing // the registry to be deallocated, all before we get a // chance to invoke `registry.notify_worker_latch_is_set`. - cross_registry = Arc::clone(self.registry); + cross_registry = Arc::clone((*this).registry); &cross_registry } else { // If this is not a "cross-registry" spin-latch, then the @@ -203,12 +208,12 @@ impl<'r> Latch for SpinLatch<'r> { // that the registry stays alive. However, that doesn't // include this *particular* `Arc` handle if the waiting // thread then exits, so we must completely dereference it. - self.registry + (*this).registry }; - let target_worker_index = self.target_worker_index; + let target_worker_index = (*this).target_worker_index; - // NOTE: Once we `set`, the target may proceed and invalidate `&self`! - if self.core_latch.set() { + // NOTE: Once we `set`, the target may proceed and invalidate `this`! + if CoreLatch::set(&(*this).core_latch) { // Subtle: at this point, we can no longer read from // `self`, because the thread owning this spin latch may // have awoken and deallocated the latch. Therefore, we @@ -255,10 +260,10 @@ impl LockLatch { impl Latch for LockLatch { #[inline] - fn set(&self) { - let mut guard = self.m.lock().unwrap(); + unsafe fn set(this: *const Self) { + let mut guard = (*this).m.lock().unwrap(); *guard = true; - self.v.notify_all(); + (*this).v.notify_all(); } } @@ -307,9 +312,9 @@ impl CountLatch { /// count, then the latch is **set**, and calls to `probe()` will /// return true. Returns whether the latch was set. #[inline] - pub(super) fn set(&self) -> bool { - if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 { - self.core_latch.set(); + pub(super) unsafe fn set(this: *const Self) -> bool { + if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 { + CoreLatch::set(&(*this).core_latch); true } else { false @@ -320,8 +325,12 @@ impl CountLatch { /// the latch is set, then the specific worker thread is tickled, /// which should be the one that owns this latch. 
#[inline] - pub(super) fn set_and_tickle_one(&self, registry: &Registry, target_worker_index: usize) { - if self.set() { + pub(super) unsafe fn set_and_tickle_one( + this: *const Self, + registry: &Registry, + target_worker_index: usize, + ) { + if Self::set(this) { registry.notify_worker_latch_is_set(target_worker_index); } } @@ -362,9 +371,9 @@ impl CountLockLatch { impl Latch for CountLockLatch { #[inline] - fn set(&self) { - if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 { - self.lock_latch.set(); + unsafe fn set(this: *const Self) { + if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 { + LockLatch::set(&(*this).lock_latch); } } } @@ -374,7 +383,7 @@ where L: Latch, { #[inline] - fn set(&self) { - L::set(self); + unsafe fn set(this: *const Self) { + L::set(&**this); } } diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index 279e298d2..33dc42e0d 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -575,7 +575,7 @@ impl Registry { pub(super) fn terminate(&self) { if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 { for (i, thread_info) in self.thread_infos.iter().enumerate() { - thread_info.terminate.set_and_tickle_one(self, i); + unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) }; } } } @@ -869,7 +869,7 @@ unsafe fn main_loop( let registry = &*worker_thread.registry; // let registry know we are ready to do work - registry.thread_infos[index].primed.set(); + Latch::set(®istry.thread_infos[index].primed); // Worker threads should not panic. If they do, just abort, as the // internal state of the threadpool is corrupted. Note that if @@ -892,7 +892,7 @@ unsafe fn main_loop( debug_assert!(worker_thread.take_local_job().is_none()); // let registry know we are done - registry.thread_infos[index].stopped.set(); + Latch::set(®istry.thread_infos[index].stopped); // Normal termination, do not abort. mem::forget(abort_guard); diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index be3e7c314..b014cf09e 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -540,10 +540,10 @@ impl<'scope> Scope<'scope> { BODY: FnOnce(&Scope<'scope>) + Send + 'scope, { let scope_ptr = ScopePtr(self); - let job = HeapJob::new(move || { + let job = HeapJob::new(move || unsafe { // SAFETY: this job will execute before the scope ends. - let scope = unsafe { scope_ptr.as_ref() }; - scope.base.execute_job(move || body(scope)) + let scope = scope_ptr.as_ref(); + ScopeBase::execute_job(&scope.base, move || body(scope)) }); let job_ref = self.base.heap_job_ref(job); @@ -562,12 +562,12 @@ impl<'scope> Scope<'scope> { BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, { let scope_ptr = ScopePtr(self); - let job = ArcJob::new(move || { + let job = ArcJob::new(move || unsafe { // SAFETY: this job will execute before the scope ends. - let scope = unsafe { scope_ptr.as_ref() }; + let scope = scope_ptr.as_ref(); let body = &body; let func = move || BroadcastContext::with(move |ctx| body(scope, ctx)); - scope.base.execute_job(func); + ScopeBase::execute_job(&scope.base, func) }); self.base.inject_broadcast(job) } @@ -600,10 +600,10 @@ impl<'scope> ScopeFifo<'scope> { BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope, { let scope_ptr = ScopePtr(self); - let job = HeapJob::new(move || { + let job = HeapJob::new(move || unsafe { // SAFETY: this job will execute before the scope ends. 
- let scope = unsafe { scope_ptr.as_ref() };
- scope.base.execute_job(move || body(scope))
+ let scope = scope_ptr.as_ref();
+ ScopeBase::execute_job(&scope.base, move || body(scope))
});
let job_ref = self.base.heap_job_ref(job);
@@ -628,12 +628,12 @@ impl<'scope> ScopeFifo<'scope> {
BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
{
let scope_ptr = ScopePtr(self);
- let job = ArcJob::new(move || {
+ let job = ArcJob::new(move || unsafe {
// SAFETY: this job will execute before the scope ends.
- let scope = unsafe { scope_ptr.as_ref() };
+ let scope = scope_ptr.as_ref();
let body = &body;
let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
- scope.base.execute_job(func);
+ ScopeBase::execute_job(&scope.base, func)
});
self.base.inject_broadcast(job)
}
@@ -688,7 +688,7 @@ impl<'scope> ScopeBase<'scope> {
where
FUNC: FnOnce() -> R,
{
- let result = self.execute_job_closure(func);
+ let result = unsafe { Self::execute_job_closure(self, func) };
self.job_completed_latch.wait(owner);
self.maybe_propagate_panic();
result.unwrap() // only None if `op` panicked, and that would have been propagated
@@ -696,28 +696,28 @@ impl<'scope> ScopeBase<'scope> {
/// Executes `func` as a job, either aborting or executing as
/// appropriate.
- fn execute_job<FUNC>(&self, func: FUNC)
+ unsafe fn execute_job<FUNC>(this: *const Self, func: FUNC)
where
FUNC: FnOnce(),
{
- let _: Option<()> = self.execute_job_closure(func);
+ let _: Option<()> = Self::execute_job_closure(this, func);
}
/// Executes `func` as a job in scope. Adjusts the "job completed"
/// counters and also catches any panic and stores it into
/// `scope`.
- fn execute_job_closure<FUNC, R>(&self, func: FUNC) -> Option<R>
+ unsafe fn execute_job_closure<FUNC, R>(this: *const Self, func: FUNC) -> Option<R>
where
FUNC: FnOnce() -> R,
{
match unwind::halt_unwinding(func) {
Ok(r) => {
- self.job_completed_latch.set();
+ Latch::set(&(*this).job_completed_latch);
Some(r)
}
Err(err) => {
- self.job_panicked(err);
- self.job_completed_latch.set();
+ (*this).job_panicked(err);
+ Latch::set(&(*this).job_completed_latch);
None
}
}
@@ -797,14 +797,14 @@ impl ScopeLatch {
}
impl Latch for ScopeLatch {
- fn set(&self) {
- match self {
+ unsafe fn set(this: *const Self) {
+ match &*this {
ScopeLatch::Stealing {
latch,
registry,
worker_index,
- } => latch.set_and_tickle_one(registry, *worker_index),
- ScopeLatch::Blocking { latch } => latch.set(),
+ } => CountLatch::set_and_tickle_one(latch, registry, *worker_index),
+ ScopeLatch::Blocking { latch } => Latch::set(latch),
}
}
}

From f880d02decba8e5e39cd302b35c013b9d69a7166 Mon Sep 17 00:00:00 2001
From: Josh Stone
Date: Thu, 19 Jan 2023 18:09:27 -0800
Subject: [PATCH 2/2] Add a virtual wrapper for &Latch
---
rayon-core/src/broadcast/mod.rs | 5 ++++-
rayon-core/src/latch.rs | 35 ++++++++++++++++++++++++++++++-----
rayon-core/src/registry.rs | 4 ++--
3 files changed, 36 insertions(+), 8 deletions(-)

diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs
index 452aa71b6..d991c5461 100644
--- a/rayon-core/src/broadcast/mod.rs
+++ b/rayon-core/src/broadcast/mod.rs
@@ -1,4 +1,5 @@
use crate::job::{ArcJob, StackJob};
+use crate::latch::LatchRef;
use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use std::fmt;
@@ -107,7 +108,9 @@ where
let n_threads = registry.num_threads();
let current_thread = WorkerThread::current().as_ref();
let latch = ScopeLatch::with_count(n_threads, current_thread);
- let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect();
+ let jobs: Vec<_> = (0..n_threads)
+ .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
+ .collect();
let job_refs = jobs.iter().map(|job| job.as_job_ref());
registry.inject_broadcast(job_refs);
diff --git a/rayon-core/src/latch.rs b/rayon-core/src/latch.rs
index b1be1e0cf..de4327234 100644
--- a/rayon-core/src/latch.rs
+++ b/rayon-core/src/latch.rs
@@ -1,3 +1,5 @@
+use std::marker::PhantomData;
+use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::usize;
@@ -378,12 +380,35 @@ impl Latch for CountLockLatch {
}
}
-impl<'a, L> Latch for &'a L
-where
- L: Latch,
-{
+/// `&L` without any implication of `dereferenceable` for `Latch::set`
+pub(super) struct LatchRef<'a, L> {
+ inner: *const L,
+ marker: PhantomData<&'a L>,
+}
+
+impl<L> LatchRef<'_, L> {
+ pub(super) fn new(inner: &L) -> LatchRef<'_, L> {
+ LatchRef {
+ inner,
+ marker: PhantomData,
+ }
+ }
+}
+
+unsafe impl<L: Sync> Sync for LatchRef<'_, L> {}
+
+impl<L> Deref for LatchRef<'_, L> {
+ type Target = L;
+
+ fn deref(&self) -> &L {
+ // SAFETY: if we have &self, the inner latch is still alive
+ unsafe { &*self.inner }
+ }
+}
+
+impl<L: Latch> Latch for LatchRef<'_, L> {
#[inline]
unsafe fn set(this: *const Self) {
- L::set(&**this);
+ L::set((*this).inner);
}
}
diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs
index 33dc42e0d..24c0855c6 100644
--- a/rayon-core/src/registry.rs
+++ b/rayon-core/src/registry.rs
@@ -1,5 +1,5 @@
use crate::job::{JobFifo, JobRef, StackJob};
-use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LockLatch, SpinLatch};
+use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
@@ -505,7 +505,7 @@ impl Registry {
assert!(injected && !worker_thread.is_null());
op(&*worker_thread, true)
},
- l,
+ LatchRef::new(l),
);
self.inject(&[job.as_job_ref()]);
job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
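
To make the shape of the change in patch 1 easier to see outside the diff, here is a minimal, self-contained sketch of the `*const Self` pattern it adopts. This is not rayon's code: the `Latch` trait is reduced to a single method, and `ToyLatch`, `probe`, and the `main` driver are invented for illustration; the real crate layers `CoreLatch`, registries, and sleep logic on top of this idea.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

/// Toy version of the trait after this patch: `set` takes a raw pointer,
/// so the signature no longer promises the latch stays dereferenceable
/// for the whole call.
trait Latch {
    /// # Safety
    /// `this` must be valid on entry; it may become dangling as a result
    /// of the set itself, because the woken thread may free the latch.
    unsafe fn set(this: *const Self);
}

/// Invented stand-in for rayon's `SpinLatch`.
struct ToyLatch {
    done: AtomicBool,
}

impl ToyLatch {
    fn new() -> Self {
        ToyLatch {
            done: AtomicBool::new(false),
        }
    }

    fn probe(&self) -> bool {
        self.done.load(Ordering::Acquire)
    }
}

impl Latch for ToyLatch {
    unsafe fn set(this: *const Self) {
        // Only a short-lived reference to the atomic field is created here;
        // once this store is visible, the owner may free the latch, so
        // nothing may touch `*this` afterward.
        (*this).done.store(true, Ordering::Release);
    }
}

fn main() {
    let latch = Arc::new(ToyLatch::new());
    let waiter = {
        let latch = Arc::clone(&latch);
        thread::spawn(move || {
            while !latch.probe() {
                thread::yield_now();
            }
            // The waiter drops its handle here; had it been the last owner,
            // the latch memory would be gone the moment `probe` returned true.
        })
    };

    // `&*latch` coerces to `*const ToyLatch`, mirroring call sites in the
    // patch such as `Latch::set(&this.latch)`.
    unsafe { Latch::set(&*latch) };
    waiter.join().unwrap();
}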
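Patch 2's `LatchRef` can be sketched the same way. Again this is a simplified, standalone illustration rather than rayon's implementation: it re-declares the same reduced trait as the sketch above, and `FlagLatch` and the `main` driver are invented. The point is that a job stores a wrapper holding `*const L`, so no `&L` field exists whose validity promise could outlive the latch, while `Deref` still gives convenient access on the waiting side.

use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, Ordering};

/// Same reduced trait as in the previous sketch.
trait Latch {
    unsafe fn set(this: *const Self);
}

/// Invented single-flag latch used only to exercise the wrapper.
struct FlagLatch {
    done: AtomicBool,
}

impl Latch for FlagLatch {
    unsafe fn set(this: *const Self) {
        (*this).done.store(true, Ordering::Release);
    }
}

/// `&L` without any implication of `dereferenceable` for `Latch::set`:
/// the borrow is tracked by `PhantomData`, but `set` only forwards the
/// raw pointer.
struct LatchRef<'a, L> {
    inner: *const L,
    marker: PhantomData<&'a L>,
}

impl<L> LatchRef<'_, L> {
    fn new(inner: &L) -> LatchRef<'_, L> {
        LatchRef {
            inner,
            marker: PhantomData,
        }
    }
}

// The wrapper only hands out shared access, so it is as `Sync` as `&L`.
unsafe impl<L: Sync> Sync for LatchRef<'_, L> {}

impl<L> Deref for LatchRef<'_, L> {
    type Target = L;

    fn deref(&self) -> &L {
        // SAFETY: the PhantomData lifetime keeps the latch borrowed for as
        // long as this wrapper exists.
        unsafe { &*self.inner }
    }
}

impl<L: Latch> Latch for LatchRef<'_, L> {
    unsafe fn set(this: *const Self) {
        // Forward the inner pointer directly; no `&L` is materialized for
        // the duration of the possibly self-invalidating call.
        L::set((*this).inner);
    }
}

fn main() {
    let latch = FlagLatch {
        done: AtomicBool::new(false),
    };
    let latch_ref = LatchRef::new(&latch);

    // `Deref` gives ordinary read access...
    assert!(!latch_ref.done.load(Ordering::Acquire));

    // ...while `set` goes through the stored raw pointer.
    unsafe { Latch::set(&latch_ref) };
    assert!(latch.done.load(Ordering::Acquire));
}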