Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Tasks functional on WASM #13889

Merged
merged 15 commits into from
Jul 16, 2024
1 change: 1 addition & 0 deletions crates/bevy_tasks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ concurrent-queue = { version = "2.0.0", optional = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen-futures = "0.4"
pin-project = "1"

[dev-dependencies]
web-time = { version = "1.1" }
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub use task_pool::{Scope, TaskPool, TaskPoolBuilder};
#[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))]
mod single_threaded_task_pool;
#[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))]
pub use single_threaded_task_pool::{FakeTask, Scope, TaskPool, TaskPoolBuilder, ThreadExecutor};
pub use single_threaded_task_pool::{LocalTask, Scope, TaskPool, TaskPoolBuilder, ThreadExecutor};

mod usages;
#[cfg(not(target_arch = "wasm32"))]
Expand Down
138 changes: 120 additions & 18 deletions crates/bevy_tasks/src/single_threaded_task_pool.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
use std::sync::Arc;
use std::{cell::RefCell, future::Future, marker::PhantomData, mem, rc::Rc};

#[cfg(target_arch = "wasm32")]
use std::{
any::Any,
cell::Cell,
panic::{AssertUnwindSafe, UnwindSafe},
task::Poll,
task::Waker,
};

thread_local! {
static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = const { async_executor::LocalExecutor::new() };
}
Expand Down Expand Up @@ -145,34 +154,50 @@ impl TaskPool {
.collect()
}

/// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
/// cancelled and "detached" allowing it to continue running without having to be polled by the
/// Spawns a static future onto the thread pool. The returned Task is a future, which can be polled
/// to retrieve the output of the original future. Dropped the task will attempt to cancel it.
JoJoJet marked this conversation as resolved.
Show resolved Hide resolved
/// It can also be "detached", allowing it to continue running without having to be polled by the
/// end-user.
///
/// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should be used instead.
pub fn spawn<T>(&self, future: impl Future<Output = T> + 'static) -> FakeTask
pub fn spawn<T>(&self, future: impl Future<Output = T> + 'static) -> LocalTask<T>
where
T: 'static,
{
#[cfg(target_arch = "wasm32")]
wasm_bindgen_futures::spawn_local(async move {
future.await;
});
{
let task = LocalTask::new();
let sender = task.0.clone();
wasm_bindgen_futures::spawn_local(async move {
// Catch any panics that occur when polling the future so they can
// be propagated back to the task handle.
let value = CatchUnwind(AssertUnwindSafe(future)).await;
// Store the value in the task. If the task handle has been dropped,
// then the value will also get dropped at the end of the scope when the
// inner task's reference count drops to zero.
sender.value.set(Some(value));
// Wake up any tasks waiting on this future
if let Some(waker) = sender.waker.take() {
waker.wake();
}
});
return task;
}

#[cfg(not(target_arch = "wasm32"))]
{
LOCAL_EXECUTOR.with(|executor| {
let _task = executor.spawn(future);
let task = executor.spawn(future);
// Loop until all tasks are done
while executor.try_tick() {}
});
}

FakeTask
LocalTask::new(task)
})
}
}

/// Spawns a static future on the JS event loop. This is exactly the same as [`TaskPool::spawn`].
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> FakeTask
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> LocalTask<T>
where
T: 'static,
{
Expand All @@ -198,15 +223,92 @@ impl TaskPool {
}
}

/// An empty task used in single-threaded contexts.
///
/// This does nothing and is therefore safe, and recommended, to ignore.
#[derive(Debug)]
pub struct FakeTask;
/// A handle to a task running on this thread. Can be awaited to obtain the output of the task.
#[cfg(not(target_arch = "wasm32"))]
pub type LocalTask<T> = crate::Task<T>;

impl FakeTask {
/// No op on the single threaded task pool
/// A handle to a task running on this thread. Can be awaited to obtain the output of the task.
#[cfg(target_arch = "wasm32")]
pub struct LocalTask<T>(Rc<InnerTask<T>>);

#[cfg(target_arch = "wasm32")]
type Panic = Box<dyn Any + Send>;

#[cfg(target_arch = "wasm32")]
struct InnerTask<T> {
value: Cell<Option<Result<T, Panic>>>,
waker: Cell<Option<Waker>>,
}
JoJoJet marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(target_arch = "wasm32")]
impl<T> LocalTask<T> {
fn new() -> Self {
let task = InnerTask {
value: Cell::new(None),
waker: Cell::new(None),
};
Self(Rc::new(task))
}
}

#[cfg(target_arch = "wasm32")]
impl<T> LocalTask<T> {
/// Allows the task to continue running independently of the current context.
/// This is a no-op.
pub fn detach(self) {}

/// Waits for the task to stop running.
///
/// In single-threaded contexts, cancellation does not work, so this method is
/// identical to just awaiting the task. This method only exists for parity with
/// the multi-threaded variant of the task pool.
pub async fn cancel(self) -> Option<T> {
Some(self.await)
}

/// Returns `true` if the current task is finished.
///
/// Unlike poll, it doesn't resolve the final value, it just checks if the task has finished.
pub fn is_finished(&self) -> bool {
let value = self.0.value.take();
let is_finished = value.is_some();
self.0.value.set(value);
is_finished
}
}

#[cfg(target_arch = "wasm32")]
impl<T> Future for LocalTask<T> {
type Output = T;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
match self.0.value.take() {
Some(Ok(value)) => Poll::Ready(value),
None => {
self.0.waker.set(Some(cx.waker().clone()));
Poll::Pending
}
Some(Err(panic)) => std::panic::resume_unwind(panic),
}
}
}

#[cfg(target_arch = "wasm32")]
impl<T> Drop for LocalTask<T> {
fn drop(&mut self) {
let _ = self.0.waker.take();
}
}

#[cfg(target_arch = "wasm32")]
#[pin_project::pin_project]
struct CatchUnwind<F: UnwindSafe>(#[pin] F);

#[cfg(target_arch = "wasm32")]
impl<F: Future + UnwindSafe> Future for CatchUnwind<F> {
type Output = Result<F::Output, Panic>;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
std::panic::catch_unwind(AssertUnwindSafe(|| self.project().0.poll(cx)))?.map(Ok)
}
}

/// A `TaskPool` scope for running one or more non-`'static` futures.
Expand Down