From 924ab8caf1e9f5f7d6f5bb85c3b035387d3da1c9 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Tue, 12 Dec 2023 20:33:08 +0100 Subject: [PATCH] Re-use threads when calling closures after releasing GIL. --- src/marker.rs | 93 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/src/marker.rs b/src/marker.rs index 54940fea7ed..d0fad90ebea 100644 --- a/src/marker.rs +++ b/src/marker.rs @@ -53,7 +53,6 @@ use crate::{ffi, FromPyPointer, IntoPy, Py, PyObject, PyTypeCheck, PyTypeInfo}; use std::ffi::{CStr, CString}; use std::marker::PhantomData; use std::os::raw::c_int; -use std::thread; /// A marker token that represents holding the GIL. /// @@ -316,6 +315,16 @@ impl<'py> Python<'py> { F: Send + FnOnce() -> T, T: Send, { + use std::mem::transmute; + use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; + use std::sync::mpsc::{sync_channel, SendError, SyncSender}; + use std::thread::{spawn, Result}; + use std::time::Duration; + + use parking_lot::{const_mutex, Mutex}; + + use crate::impl_::panic::PanicTrap; + // Use a guard pattern to handle reacquiring the GIL, // so that the GIL will be reacquired even if `f` panics. // The `Send` bound on the closure prevents the user from @@ -323,9 +332,87 @@ impl<'py> Python<'py> { let _guard = unsafe { SuspendGIL::new() }; // To close soundness loopholes w.r.t. `send_wrapper` or `scoped-tls`, - // we run the closure on a newly created thread so that it cannot + // we run the closure on a separate thread so that it cannot // access thread-local storage from the current thread. - thread::scope(|s| s.spawn(f).join().unwrap()) + + // 1. Construct a task + struct Task(*mut dyn FnMut()); + unsafe impl Send for Task {} + + let (result_sender, result_receiver) = sync_channel::>(0); + + let mut f = Some(f); + + let mut task = || { + let f = f.take().unwrap(); + + let result = catch_unwind(AssertUnwindSafe(f)); + + result_sender.send(result).unwrap(); + }; + + // SAFETY: the current thread will block until the closure has returned + let task = Task(unsafe { transmute(&mut task as &mut dyn FnMut()) }); + + // 2. Dispatch task to waiting thread, spawn new thread if necessary + let trap = PanicTrap::new( + "allow_threads panicked while stack data was accessed by another thread which is a bug", + ); + + static THREADS: Mutex>> = const_mutex(Vec::new()); + + enum State { + Pending(Task), + Dispatched(SyncSender), + } + + let mut state = State::Pending(task); + + while let Some(task_sender) = THREADS.lock().pop() { + match state { + State::Pending(task) => match task_sender.send(task) { + Ok(()) => { + state = State::Dispatched(task_sender); + break; + } + Err(SendError(task)) => { + state = State::Pending(task); + continue; + } + }, + State::Dispatched(_task_sender) => unreachable!(), + } + } + + let task_sender = match state { + State::Pending(task) => { + let (task_sender, task_receiver) = sync_channel::(0); + + spawn(move || { + while let Ok(task) = task_receiver.recv_timeout(Duration::from_secs(60)) { + // SAFETY: all data accessed by `task` will stay alive until it completes + unsafe { (*task.0)() }; + } + }); + + task_sender.send(task).unwrap(); + + task_sender + } + State::Dispatched(task_sender) => task_sender, + }; + + // 3. Wait for completion and check result + let result = result_receiver.recv().unwrap(); + + trap.disarm(); + + THREADS.lock().push(task_sender); + + match result { + Ok(result) => result, + Err(payload) => resume_unwind(payload), + } } /// Evaluates a Python expression in the given context and returns the result.