diff --git a/src/marker.rs b/src/marker.rs index 54940fea7ed..6ea80dc2307 100644 --- a/src/marker.rs +++ b/src/marker.rs @@ -53,7 +53,6 @@ use crate::{ffi, FromPyPointer, IntoPy, Py, PyObject, PyTypeCheck, PyTypeInfo}; use std::ffi::{CStr, CString}; use std::marker::PhantomData; use std::os::raw::c_int; -use std::thread; /// A marker token that represents holding the GIL. /// @@ -316,6 +315,13 @@ impl<'py> Python<'py> { F: Send + FnOnce() -> T, T: Send, { + use parking_lot::{const_mutex, Mutex}; + use std::mem::{transmute, ManuallyDrop, MaybeUninit}; + use std::panic::{catch_unwind, AssertUnwindSafe}; + use std::sync::mpsc::{sync_channel, SendError, SyncSender}; + use std::thread::{spawn, Result}; + use std::time::Duration; + // Use a guard pattern to handle reacquiring the GIL, // so that the GIL will be reacquired even if `f` panics. // The `Send` bound on the closure prevents the user from @@ -323,9 +329,80 @@ impl<'py> Python<'py> { let _guard = unsafe { SuspendGIL::new() }; // To close soundness loopholes w.r.t. `send_wrapper` or `scoped-tls`, - // we run the closure on a newly created thread so that it cannot + // we run the closure on a separate thread so that it cannot // access thread-local storage from the current thread. - thread::scope(|s| s.spawn(f).join().unwrap()) + + // Construct a task + struct Task(*mut dyn FnMut()); + unsafe impl Send for Task {} + + let mut f = ManuallyDrop::new(f); + let mut result = MaybeUninit::>::uninit(); + + let (result_sender, result_receiver) = sync_channel(0); + + let mut task = || { + // SAFETY: `F` is `Send` and we ensure that this closure is called at most once + let f = unsafe { ManuallyDrop::take(&mut f) }; + + result.write(catch_unwind(AssertUnwindSafe(f))); + + let _ = result_sender.send(()); + }; + // SAFETY: the current thread will block until the closure has returned + let task = Task(unsafe { transmute(&mut task as &mut dyn FnMut()) }); + + // Enqueue task and spawn thread if necessary + static THREADS: Mutex>> = const_mutex(Vec::new()); + + enum State { + Pending(Task), + Dispatched(SyncSender), + } + + let mut state = State::Pending(task); + + while let Some(task_sender) = THREADS.lock().pop() { + match state { + State::Pending(task) => match task_sender.send(task) { + Ok(()) => { + state = State::Dispatched(task_sender); + break; + } + Err(SendError(task)) => { + state = State::Pending(task); + continue; + } + }, + State::Dispatched(_sender) => unreachable!(), + } + } + + let task_sender = match state { + State::Pending(task) => { + let (task_sender, task_receiver) = sync_channel::(0); + + spawn(move || { + while let Ok(task) = task_receiver.recv_timeout(Duration::from_secs(60)) { + // SAFETY: all data accessed by `task` will stay alive until it completes + unsafe { (*task.0)() }; + } + }); + + task_sender.send(task).unwrap(); + + task_sender + } + State::Dispatched(task_sender) => task_sender, + }; + + // Wait for completion and read result + result_receiver.recv().unwrap(); + + THREADS.lock().push(task_sender); + + // SAFETY: the task completed and hence initialized `result` + unsafe { result.assume_init().unwrap() } } /// Evaluates a Python expression in the given context and returns the result.