Skip to content

Commit

Permalink
core: registry: Allow using the current thread in a new pool.
Browse files Browse the repository at this point in the history
See discussion in #1052.

Closes #1052.
  • Loading branch information
cuviper authored and emilio committed Sep 20, 2023
1 parent dc7090a commit 9461f7b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 20 deletions.
13 changes: 13 additions & 0 deletions rayon-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
/// If RAYON_NUM_THREADS is invalid or zero will use the default.
num_threads: usize,

/// The thread we're building *from* will also be part of the pool.
use_current: bool,

/// Custom closure, if any, to handle a panic that we cannot propagate
/// anywhere else.
panic_handler: Option<Box<PanicHandler>>,
Expand Down Expand Up @@ -227,6 +230,7 @@ impl Default for ThreadPoolBuilder {
fn default() -> Self {
ThreadPoolBuilder {
num_threads: 0,
use_current: false,
panic_handler: None,
get_thread_name: None,
stack_size: None,
Expand Down Expand Up @@ -437,6 +441,7 @@ impl<S> ThreadPoolBuilder<S> {
spawn_handler: CustomSpawn::new(spawn),
// ..self
num_threads: self.num_threads,
use_current: self.use_current,
panic_handler: self.panic_handler,
get_thread_name: self.get_thread_name,
stack_size: self.stack_size,
Expand Down Expand Up @@ -529,6 +534,12 @@ impl<S> ThreadPoolBuilder<S> {
self
}

/// Use the current thread as one of the threads in the pool.
pub fn use_current(mut self) -> Self {
self.use_current = true;
self
}

/// Returns a copy of the current panic handler.
fn take_panic_handler(&mut self) -> Option<Box<PanicHandler>> {
self.panic_handler.take()
Expand Down Expand Up @@ -768,6 +779,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let ThreadPoolBuilder {
ref num_threads,
ref use_current,
ref get_thread_name,
ref panic_handler,
ref stack_size,
Expand All @@ -792,6 +804,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {

f.debug_struct("ThreadPoolBuilder")
.field("num_threads", num_threads)
.field("use_current", use_current)
.field("get_thread_name", &get_thread_name)
.field("panic_handler", &panic_handler)
.field("stack_size", &stack_size)
Expand Down
36 changes: 16 additions & 20 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,26 +207,7 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
// is stubbed out, and we won't have to change anything if they do add real threading.
let unsupported = matches!(&result, Err(e) if e.is_unsupported());
if unsupported && WorkerThread::current().is_null() {
let builder = ThreadPoolBuilder::new()
.num_threads(1)
.spawn_handler(|thread| {
// Rather than starting a new thread, we're just taking over the current thread
// *without* running the main loop, so we can still return from here.
// The WorkerThread is leaked, but we never shutdown the global pool anyway.
let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
let registry = &*worker_thread.registry;
let index = worker_thread.index;

unsafe {
WorkerThread::set_current(worker_thread);

// let registry know we are ready to do work
Latch::set(&registry.thread_infos[index].primed);
}

Ok(())
});

let builder = ThreadPoolBuilder::new().num_threads(1).use_current();
let fallback_result = Registry::new(builder);
if fallback_result.is_ok() {
return fallback_result;
Expand Down Expand Up @@ -300,6 +281,21 @@ impl Registry {
stealer,
index,
};

if index == 0 && builder.use_current {
// Rather than starting a new thread, we're just taking over the current thread
// *without* running the main loop, so we can still return from here.
// The WorkerThread is leaked, but we never shutdown the global pool anyway.
// TODO: what about non-global thread pools?
let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));

unsafe {
WorkerThread::set_current(worker_thread);
Latch::set(&registry.thread_infos[index].primed);
}
continue;
}

if let Err(e) = builder.get_spawn_handler().spawn(thread) {
return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
}
Expand Down

0 comments on commit 9461f7b

Please sign in to comment.