Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: unify entering a runtime with Handle::enter #5163

Merged
merged 2 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tokio/src/future/block_on.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ use std::future::Future;
cfg_rt! {
#[track_caller]
pub(crate) fn block_on<F: Future>(f: F) -> F::Output {
let mut e = crate::runtime::enter::enter(false);
let mut e = crate::runtime::enter::try_enter_blocking_region().expect(
"Cannot block the current thread from within a runtime. This \
happens because a functionattempted to block the current \
thread while the thread is being used to drive asynchronous \
tasks."
);
e.block_on(f).unwrap()
}
}
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/runtime/blocking/shutdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ impl Receiver {
///
/// If the timeout has elapsed, it returns `false`, otherwise it returns `true`.
pub(crate) fn wait(&mut self, timeout: Option<Duration>) -> bool {
use crate::runtime::enter::try_enter;
use crate::runtime::enter::try_enter_blocking_region;

if timeout == Some(Duration::from_nanos(0)) {
return false;
}

let mut e = match try_enter(false) {
let mut e = match try_enter_blocking_region() {
Some(enter) => enter,
_ => {
if std::thread::panicking() {
Expand Down
54 changes: 41 additions & 13 deletions tokio/src/runtime/enter.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use crate::runtime::scheduler;

use std::cell::{Cell, RefCell};
use std::fmt;
use std::marker::PhantomData;

#[derive(Debug, Clone, Copy)]
pub(crate) enum EnterContext {
/// Currently in a runtime context.
#[cfg_attr(not(feature = "rt"), allow(dead_code))]
Entered {
allow_block_in_place: bool,
},
Entered { allow_block_in_place: bool },

/// Not in a runtime context **or** a blocking region.
NotEntered,
}

Expand All @@ -19,19 +22,29 @@ impl EnterContext {

tokio_thread_local!(static ENTERED: Cell<EnterContext> = const { Cell::new(EnterContext::NotEntered) });

/// Represents an executor context.
pub(crate) struct Enter {
/// Guard tracking that a caller has entered a runtime context.
pub(crate) struct EnterRuntimeGuard {
pub(crate) blocking: BlockingRegionGuard,
}

/// Guard tracking that a caller has entered a blocking region.
pub(crate) struct BlockingRegionGuard {
_p: PhantomData<RefCell<()>>,
}

cfg_rt! {
use crate::runtime::context;

use std::time::Duration;

/// Marks the current thread as being within the dynamic extent of an
/// executor.
#[track_caller]
pub(crate) fn enter(allow_block_in_place: bool) -> Enter {
if let Some(enter) = try_enter(allow_block_in_place) {
pub(crate) fn enter_runtime(handle: &scheduler::Handle, allow_block_in_place: bool) -> EnterRuntimeGuard {
if let Some(enter) = try_enter_runtime(allow_block_in_place) {
// Set the current runtime handle. This should not fail. A later
// cleanup will remove the unwrap().
context::try_set_current(handle).unwrap();
return enter;
}

Expand All @@ -45,13 +58,25 @@ cfg_rt! {

/// Tries to enter a runtime context, returns `None` if already in a runtime
/// context.
pub(crate) fn try_enter(allow_block_in_place: bool) -> Option<Enter> {
fn try_enter_runtime(allow_block_in_place: bool) -> Option<EnterRuntimeGuard> {
ENTERED.with(|c| {
if c.get().is_entered() {
None
} else {
c.set(EnterContext::Entered { allow_block_in_place });
Some(Enter { _p: PhantomData })
Some(EnterRuntimeGuard {
blocking: BlockingRegionGuard::new(),
})
}
})
}

pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> {
ENTERED.with(|c| {
if c.get().is_entered() {
None
} else {
Some(BlockingRegionGuard::new())
}
})
}
Expand All @@ -65,7 +90,7 @@ cfg_rt! {
// This is hidden for a reason. Do not use without fully understanding
// executors. Misusing can easily cause your program to deadlock.
cfg_rt_multi_thread! {
pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R {
pub(crate) fn exit_runtime<F: FnOnce() -> R, R>(f: F) -> R {
// Reset in case the closure panics
struct Reset(EnterContext);
impl Drop for Reset {
Expand Down Expand Up @@ -139,7 +164,10 @@ cfg_rt_multi_thread! {
cfg_rt! {
use crate::loom::thread::AccessError;

impl Enter {
impl BlockingRegionGuard {
fn new() -> BlockingRegionGuard {
BlockingRegionGuard { _p: PhantomData }
}
/// Blocks the thread on the specified future, returning the value with
/// which that future completes.
pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, AccessError>
Expand Down Expand Up @@ -189,13 +217,13 @@ cfg_rt! {
}
}

impl fmt::Debug for Enter {
impl fmt::Debug for EnterRuntimeGuard {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Enter").finish()
}
}

impl Drop for Enter {
impl Drop for EnterRuntimeGuard {
fn drop(&mut self) {
ENTERED.with(|c| {
assert!(c.get().is_entered());
Expand Down
11 changes: 5 additions & 6 deletions tokio/src/runtime/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,13 @@ impl Handle {
let future =
crate::util::trace::task(future, "block_on", None, super::task::Id::next().as_u64());

// Enter the **runtime** context. This configures spawning, the current I/O driver, ...
let _rt_enter = self.enter();

// Enter a **blocking** context. This prevents blocking from a runtime.
let mut blocking_enter = crate::runtime::enter(true);
// Enter the runtime context. This sets the current driver handles and
// prevents blocking an existing runtime.
let mut enter = crate::runtime::enter::enter_runtime(&self.inner, true);

// Block on the future
blocking_enter
enter
.blocking
.block_on(future)
.expect("failed to park thread")
}
Expand Down
2 changes: 1 addition & 1 deletion tokio/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ cfg_rt! {
pub use crate::util::rand::RngSeed;
}

use self::enter::enter;
use self::enter::enter_runtime;

mod handle;
pub use handle::{EnterGuard, Handle, TryCurrentError};
Expand Down
9 changes: 6 additions & 3 deletions tokio/src/runtime/scheduler/current_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,25 @@ impl CurrentThread {

#[track_caller]
pub(crate) fn block_on<F: Future>(&self, future: F) -> F::Output {
use crate::runtime::scheduler;

pin!(future);

let handle = scheduler::Handle::CurrentThread(self.handle.clone());
let mut enter = crate::runtime::enter_runtime(&handle, false);

// Attempt to steal the scheduler core and block_on the future if we can
// there, otherwise, lets select on a notification that the core is
// available or the future is complete.
loop {
if let Some(core) = self.take_core() {
return core.block_on(future);
} else {
let mut enter = crate::runtime::enter(false);

let notified = self.notify.notified();
pin!(notified);

if let Some(out) = enter
.blocking
.block_on(poll_fn(|cx| {
if notified.as_mut().poll(cx).is_ready() {
return Ready(None);
Expand Down Expand Up @@ -522,7 +526,6 @@ impl CoreGuard<'_> {
#[track_caller]
fn block_on<F: Future>(self, future: F) -> F::Output {
let ret = self.enter(|mut core, context| {
let _enter = crate::runtime::enter(false);
let waker = Handle::waker_ref(&context.handle);
let mut cx = std::task::Context::from_waker(&waker);

Expand Down
10 changes: 7 additions & 3 deletions tokio/src/runtime/scheduler/multi_thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::loom::sync::Arc;
use crate::runtime::{
blocking,
driver::{self, Driver},
Config,
scheduler, Config,
};
use crate::util::RngSeedGenerator;

Expand Down Expand Up @@ -73,8 +73,12 @@ impl MultiThread {
where
F: Future,
{
let mut enter = crate::runtime::enter(true);
enter.block_on(future).expect("failed to park thread")
let handle = scheduler::Handle::MultiThread(self.handle.clone());
let mut enter = crate::runtime::enter_runtime(&handle, true);
enter
.blocking
.block_on(future)
.expect("failed to park thread")
}
}

Expand Down
9 changes: 5 additions & 4 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ use crate::runtime::enter::EnterContext;
use crate::runtime::scheduler::multi_thread::{queue, Handle, Idle, Parker, Unparker};
use crate::runtime::task::{Inject, OwnedTasks};
use crate::runtime::{
blocking, coop, driver, task, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics,
blocking, coop, driver, scheduler, task, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics,
};
use crate::util::atomic_cell::AtomicCell;
use crate::util::rand::{FastRand, RngSeedGenerator};
Expand Down Expand Up @@ -350,7 +350,7 @@ where
// constrained by task budgets.
let _reset = Reset(coop::stop());

crate::runtime::enter::exit(f)
crate::runtime::enter::exit_runtime(f)
} else {
f()
}
Expand All @@ -372,14 +372,15 @@ fn run(worker: Arc<Worker>) {
None => return,
};

let handle = scheduler::Handle::MultiThread(worker.handle.clone());
let _enter = crate::runtime::enter_runtime(&handle, true);

// Set the worker context.
let cx = Context {
worker,
core: RefCell::new(None),
};

let _enter = crate::runtime::enter(true);

CURRENT.set(&cx, || {
// This should always be an error. It only returns a `Result` to support
// using `?` to short circuit.
Expand Down