Merge pull request #210 from laa/thread_safe_waker
Fixes issue #194: all changes to the task status are guarded by a check of the thread_id.
Glauber Costa authored Dec 8, 2020
2 parents c8eb3bf + 8d2e54c commit 1e444c3
Showing 2 changed files with 91 additions and 15 deletions.
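Before the diffs, a minimal sketch of the pattern this commit applies, with hypothetical names rather than glommio's actual API: capture the owning thread's ThreadId when the task is created, then assert it on every state change, so a Waker smuggled to another thread panics loudly instead of silently corrupting the task's non-atomic state.

use std::thread::{self, ThreadId};

struct LocalState {
    owner: ThreadId, // thread that created this state and may mutate it
    scheduled: bool, // plain, non-atomic flag: single-threaded access only
}

impl LocalState {
    fn new() -> Self {
        LocalState {
            owner: thread::current().id(),
            scheduled: false,
        }
    }

    fn wake(&mut self) {
        // Guard every state change with the owning thread's id, as the
        // commit does for wake, wake_by_ref, clone and drop.
        assert_eq!(
            thread::current().id(),
            self.owner,
            "Waker used outside of the thread that owns the task"
        );
        self.scheduled = true;
    }
}

fn main() {
    let mut state = LocalState::new();
    state.wake(); // same thread: fine
    // Calling wake from another thread would trip the assert instead of racing.
}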
5 changes: 5 additions & 0 deletions glommio/src/task/header.rs
@@ -9,11 +9,15 @@ use core::task::Waker;
use crate::task::raw::TaskVTable;
use crate::task::state::*;
use crate::task::utils::abort_on_panic;
use std::thread::ThreadId;

/// The header of a task.
///
/// This header is stored right at the beginning of every heap-allocated task.
pub(crate) struct Header {
/// ID of the executor thread to which the task belongs, in other words
/// the thread by which the task was spawned.
pub(crate) thread_id: ThreadId,
/// Current state of the task.
///
/// Contains flags representing the current state and the reference count.
@@ -84,6 +88,7 @@ impl fmt::Debug for Header {
let state = self.state;

f.debug_struct("Header")
.field("thread_id", &self.thread_id)
.field("scheduled", &(state & SCHEDULED != 0))
.field("running", &(state & RUNNING != 0))
.field("completed", &(state & COMPLETED != 0))
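For reference, a standalone check of the std API the new header field relies on: std::thread::ThreadId is an opaque, Copy + Eq token that is unique per thread, cheap to store in a header, and cheap to compare on every state change.

use std::thread;

fn main() {
    let main_id = thread::current().id();
    let child_id = thread::spawn(|| thread::current().id()).join().unwrap();
    assert_ne!(main_id, child_id); // every thread gets a distinct id
}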
101 changes: 86 additions & 15 deletions glommio/src/task/raw.rs
@@ -14,6 +14,9 @@ use crate::task::header::Header;
use crate::task::state::*;
use crate::task::utils::{abort, abort_on_panic, extend};
use crate::task::Task;
use std::thread::ThreadId;

thread_local!(static THREAD_ID: ThreadId = std::thread::current().id());

/// The vtable for a task.
pub(crate) struct TaskVTable {
@@ -110,6 +113,7 @@ where

// Write the header as the first field of the task.
(raw.header as *mut Header).write(Header {
thread_id: Self::thread_id(),
state: SCHEDULED | HANDLE | REFERENCE,
awaiter: None,
vtable: &TaskVTable {
@@ -132,6 +136,12 @@
}
}

fn thread_id() -> ThreadId {
THREAD_ID
.try_with(|id| *id)
.unwrap_or_else(|_e| std::thread::current().id())
}
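// Editor's sketch, not part of the diff: LocalKey::try_with returns an
// AccessError once the thread-local value has been destroyed (for example
// when a waker is dropped during thread teardown), so the fallback above
// re-queries the id directly instead of panicking. A standalone equivalent:
fn cached_thread_id() -> std::thread::ThreadId {
    thread_local!(static CACHED: std::thread::ThreadId = std::thread::current().id());
    CACHED
        .try_with(|id| *id) // fast path: copy the cached id
        .unwrap_or_else(|_| std::thread::current().id()) // TLS gone: re-query
}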

/// Creates a `RawTask` from a raw task pointer.
#[inline]
pub(crate) fn from_ptr(ptr: *const ()) -> Self {
@@ -183,25 +193,31 @@
// we'll do less reference counting if we wake the waker by reference and then drop it.
if mem::size_of::<S>() > 0 {
Self::wake_by_ref(ptr);
Self::drop_waker(ptr);
Self::drop_waker_reference(ptr);
return;
}

let raw = Self::from_ptr(ptr);
assert_eq!(
Self::thread_id(),
(*raw.header).thread_id,
"Waker::wake is called outside of working thread. \
Waker instances can not be moved to or work with multiple threads"
);

let state = (*raw.header).state;

// If the task is completed or closed, it can't be woken up.
if state & (COMPLETED | CLOSED) != 0 {
// Drop the waker.
Self::drop_waker(ptr);
Self::drop_waker_reference(ptr);
return;
}

// If the task is already scheduled do nothing.
if state & SCHEDULED != 0 {
// Drop the waker.
Self::drop_waker(ptr);
Self::drop_waker_reference(ptr);
} else {
// Mark the task as scheduled.
(*(raw.header as *mut Header)).state = state | SCHEDULED;
@@ -210,7 +226,7 @@
Self::schedule(ptr);
} else {
// Drop the waker.
Self::drop_waker(ptr);
Self::drop_waker_reference(ptr);
}
}
}
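// Editor's sketch, not part of the diff: the state word is a plain usize of
// bitflags, which is sound only because the assert above pins all access to a
// single thread. The flag values below are illustrative, not glommio's own:
const DEMO_SCHEDULED: usize = 1 << 0;
const DEMO_COMPLETED: usize = 1 << 1;
const DEMO_CLOSED: usize = 1 << 2;

fn demo_wake_transition(state: usize) -> Option<usize> {
    if state & (DEMO_COMPLETED | DEMO_CLOSED) != 0 {
        return None; // completed or closed tasks cannot be woken
    }
    if state & DEMO_SCHEDULED != 0 {
        return None; // already scheduled: nothing to do
    }
    Some(state | DEMO_SCHEDULED) // mark scheduled, then hand off to schedule()
}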
@@ -219,6 +235,13 @@
unsafe fn wake_by_ref(ptr: *const ()) {
let raw = Self::from_ptr(ptr);

assert_eq!(
Self::thread_id(),
(*raw.header).thread_id,
"Waker::wake_by_ref is called outside of working thread. \
Waker instances can not be moved to or work with multiple threads"
);

let state = (*raw.header).state;

// If the task is completed or closed, it can't be woken up.
@@ -261,17 +284,27 @@
unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
let raw = Self::from_ptr(ptr);

// Increment the reference count. With any kind of reference-counted data structure,
// relaxed ordering is appropriate when incrementing the counter.
let state = (*raw.header).state;
(*(raw.header as *mut Header)).state += REFERENCE;
assert_eq!(
Self::thread_id(),
(*raw.header).thread_id,
"Waker::clone is called outside of working thread. \
Waker instances can not be moved to or work with multiple threads"
);

Self::increment_references(&mut *(raw.header as *mut Header));

RawWaker::new(ptr, &Self::RAW_WAKER_VTABLE)
}

#[inline]
fn increment_references(header: &mut Header) {
let state = header.state;
header.state += REFERENCE;

// If the reference count overflowed, abort.
if state > isize::max_value() as usize {
abort();
}

RawWaker::new(ptr, &Self::RAW_WAKER_VTABLE)
}
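// Editor's sketch, not part of the diff: the reference count shares the state
// word with the flag bits, so it is bumped by a whole REFERENCE unit rather
// than by 1. The layout below is illustrative, not glommio's actual one:
const DEMO_REFERENCE: usize = 1 << 8; // count lives above the flag bits

fn demo_increment_references(state: &mut usize) {
    let old = *state;
    *state += DEMO_REFERENCE; // one more Waker or handle keeps the task alive
    // On overflow into the sign bit, abort: unwinding past a corrupted header
    // could run user Drop code over invalid data.
    if old > isize::MAX as usize {
        std::process::abort();
    }
}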

/// Drops a waker.
@@ -281,8 +314,20 @@
/// scheduled one more time so that its future gets dropped by the executor.
#[inline]
unsafe fn drop_waker(ptr: *const ()) {
let raw = Self::from_ptr(ptr);
let header = ptr as *const Header;
assert_eq!(
Self::thread_id(),
(*header).thread_id,
"Waker::drop is called outside of working thread. \
Waker instances can not be moved to or work with multiple threads"
);

<RawTask<F, R, S>>::drop_waker_reference(ptr)
}

#[inline]
unsafe fn drop_waker_reference(ptr: *const ()) {
let raw = Self::from_ptr(ptr);
// Decrement the reference count.
let new = (*raw.header).state - REFERENCE;
(*(raw.header as *mut Header)).state = new;
@@ -328,18 +373,44 @@
unsafe fn schedule(ptr: *const ()) {
let raw = Self::from_ptr(ptr);

// If the schedule function has captured variables, create a temporary waker that prevents
// the task from getting deallocated while the function is being invoked.
let _waker;
struct Guard<'a, F, R, S>(&'a RawTask<F, R, S>)
where
F: Future<Output = R> + 'static,
S: Fn(Task) + 'static;

impl<'a, F, R, S> Drop for Guard<'a, F, R, S>
where
F: Future<Output = R> + 'static,
S: Fn(Task) + 'static,
{
fn drop(&mut self) {
let raw = self.0;
let ptr = raw.header as *const ();

unsafe {
RawTask::<F, R, S>::drop_waker_reference(ptr);
}
}
}

let guard;
// Calling the schedule function itself does not increment the reference
// count. If the schedule function has captured variables, increment the
// count so that the task's data stays valid until the call finishes,
// even if the task is dropped inside the schedule function.
if mem::size_of::<S>() > 0 {
_waker = Waker::from_raw(Self::clone_waker(ptr));
Self::increment_references(&mut *(raw.header as *mut Header));
guard = Some(Guard(&raw));
} else {
guard = None;
}

let task = Task {
raw_task: NonNull::new_unchecked(ptr as *mut ()),
};

(*raw.schedule)(task);
drop(guard);
}
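// Editor's sketch, not part of the diff: the Guard above is plain RAII. The
// reference taken before invoking the user's schedule closure is released
// when the guard drops, even if the closure panics or drops the task itself
// mid-call. The generic shape of the pattern:
use std::cell::Cell;

struct RefGuard<'a>(&'a Cell<usize>);

impl Drop for RefGuard<'_> {
    fn drop(&mut self) {
        self.0.set(self.0.get() - 1); // always released, even on unwind
    }
}

fn with_reference<R>(count: &Cell<usize>, f: impl FnOnce() -> R) -> R {
    count.set(count.get() + 1); // keep the shared data alive across the call
    let _guard = RefGuard(count);
    f()
}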

/// Drops the future inside a task.
