Skip to content

Commit

Permalink
Re-use threads when calling closures after releasing GIL.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Dec 13, 2023
1 parent 3e2dac8 commit 924ab8c
Showing 1 changed file with 90 additions and 3 deletions.
93 changes: 90 additions & 3 deletions src/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ use crate::{ffi, FromPyPointer, IntoPy, Py, PyObject, PyTypeCheck, PyTypeInfo};
use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::os::raw::c_int;
use std::thread;

/// A marker token that represents holding the GIL.
///
Expand Down Expand Up @@ -316,16 +315,104 @@ impl<'py> Python<'py> {
F: Send + FnOnce() -> T,
T: Send,
{
use std::mem::transmute;
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
use std::sync::mpsc::{sync_channel, SendError, SyncSender};
use std::thread::{spawn, Result};
use std::time::Duration;

use parking_lot::{const_mutex, Mutex};

use crate::impl_::panic::PanicTrap;

// Use a guard pattern to handle reacquiring the GIL,
// so that the GIL will be reacquired even if `f` panics.
// The `Send` bound on the closure prevents the user from
// transferring the `Python` token into the closure.
let _guard = unsafe { SuspendGIL::new() };

// To close soundness loopholes w.r.t. `send_wrapper` or `scoped-tls`,
// we run the closure on a newly created thread so that it cannot
// we run the closure on a separate thread so that it cannot
// access thread-local storage from the current thread.
thread::scope(|s| s.spawn(f).join().unwrap())

// 1. Construct a task
struct Task(*mut dyn FnMut());
unsafe impl Send for Task {}

let (result_sender, result_receiver) = sync_channel::<Result<T>>(0);

let mut f = Some(f);

let mut task = || {
let f = f.take().unwrap();

let result = catch_unwind(AssertUnwindSafe(f));

result_sender.send(result).unwrap();
};

// SAFETY: the current thread will block until the closure has returned
let task = Task(unsafe { transmute(&mut task as &mut dyn FnMut()) });

// 2. Dispatch task to waiting thread, spawn new thread if necessary
let trap = PanicTrap::new(
"allow_threads panicked while stack data was accessed by another thread which is a bug",
);

static THREADS: Mutex<Vec<SyncSender<Task>>> = const_mutex(Vec::new());

enum State {
Pending(Task),
Dispatched(SyncSender<Task>),
}

let mut state = State::Pending(task);

while let Some(task_sender) = THREADS.lock().pop() {
match state {
State::Pending(task) => match task_sender.send(task) {
Ok(()) => {
state = State::Dispatched(task_sender);
break;
}
Err(SendError(task)) => {
state = State::Pending(task);
continue;
}
},
State::Dispatched(_task_sender) => unreachable!(),
}
}

let task_sender = match state {
State::Pending(task) => {
let (task_sender, task_receiver) = sync_channel::<Task>(0);

spawn(move || {
while let Ok(task) = task_receiver.recv_timeout(Duration::from_secs(60)) {
// SAFETY: all data accessed by `task` will stay alive until it completes
unsafe { (*task.0)() };
}
});

task_sender.send(task).unwrap();

task_sender
}
State::Dispatched(task_sender) => task_sender,
};

// 3. Wait for completion and check result
let result = result_receiver.recv().unwrap();

trap.disarm();

THREADS.lock().push(task_sender);

match result {
Ok(result) => result,
Err(payload) => resume_unwind(payload),
}
}

/// Evaluates a Python expression in the given context and returns the result.
Expand Down

0 comments on commit 924ab8c

Please sign in to comment.