Skip to content

Commit

Permalink
Merge pull request #66 from saethlin/master
Browse files Browse the repository at this point in the history
Avoid lossy ptr-int transmutes by using AtomicPtr
  • Loading branch information
yoshuawuyts committed Feb 22, 2024
2 parents 30e562c + b0e5867 commit cb6f0de
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 26 deletions.
31 changes: 19 additions & 12 deletions src/native/arc_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
use std::marker;
use std::ops::Deref;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::atomic::{AtomicBool, AtomicPtr};
use std::sync::Arc;

pub struct ArcList<T> {
list: AtomicUsize,
list: AtomicPtr<Node<T>>,
_marker: marker::PhantomData<T>,
}

impl<T> ArcList<T> {
pub fn new() -> ArcList<T> {
ArcList {
list: AtomicUsize::new(0),
list: AtomicPtr::new(Node::EMPTY),
_marker: marker::PhantomData,
}
}
Expand All @@ -31,10 +31,10 @@ impl<T> ArcList<T> {
return Ok(());
}
let mut head = self.list.load(SeqCst);
let node = Arc::into_raw(data.clone()) as usize;
let node = Arc::into_raw(data.clone()) as *mut Node<T>;
loop {
// If we've been sealed off, abort and return an error
if head == 1 {
if head == Node::SEALED {
unsafe {
drop(Arc::from_raw(node as *mut Node<T>));
}
Expand All @@ -55,16 +55,19 @@ impl<T> ArcList<T> {
pub fn take(&self) -> ArcList<T> {
let mut list = self.list.load(SeqCst);
loop {
if list == 1 {
if list == Node::SEALED {
break;
}
match self.list.compare_exchange(list, 0, SeqCst, SeqCst) {
match self
.list
.compare_exchange(list, Node::EMPTY, SeqCst, SeqCst)
{
Ok(_) => break,
Err(l) => list = l,
}
}
ArcList {
list: AtomicUsize::new(list),
list: AtomicPtr::new(list),
_marker: marker::PhantomData,
}
}
Expand All @@ -73,7 +76,7 @@ impl<T> ArcList<T> {
/// `push`.
pub fn take_and_seal(&self) -> ArcList<T> {
ArcList {
list: AtomicUsize::new(self.list.swap(1, SeqCst)),
list: AtomicPtr::new(self.list.swap(Node::SEALED, SeqCst)),
_marker: marker::PhantomData,
}
}
Expand All @@ -82,7 +85,7 @@ impl<T> ArcList<T> {
/// empty list.
pub fn pop(&mut self) -> Option<Arc<Node<T>>> {
let head = *self.list.get_mut();
if head == 0 || head == 1 {
if head == Node::EMPTY || head == Node::SEALED {
return None;
}
let head = unsafe { Arc::from_raw(head as *const Node<T>) };
Expand All @@ -103,15 +106,19 @@ impl<T> Drop for ArcList<T> {
}

pub struct Node<T> {
next: AtomicUsize,
next: AtomicPtr<Node<T>>,
enqueued: AtomicBool,
data: T,
}

impl<T> Node<T> {
const EMPTY: *mut Node<T> = std::ptr::null_mut();

const SEALED: *mut Node<T> = std::ptr::null_mut::<Node<T>>().wrapping_add(1);

pub fn new(data: T) -> Node<T> {
Node {
next: AtomicUsize::new(0),
next: AtomicPtr::new(Node::EMPTY),
enqueued: AtomicBool::new(false),
data,
}
Expand Down
28 changes: 14 additions & 14 deletions src/native/timer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::fmt;
use std::mem;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::{AtomicPtr, AtomicUsize};
use std::sync::{Arc, Mutex, Weak};
use std::task::{Context, Poll};
use std::time::Instant;
Expand Down Expand Up @@ -216,7 +215,8 @@ impl Default for Timer {
}
}

static HANDLE_FALLBACK: AtomicUsize = AtomicUsize::new(0);
static HANDLE_FALLBACK: AtomicPtr<Inner> = AtomicPtr::new(EMPTY_HANDLE);
const EMPTY_HANDLE: *mut Inner = std::ptr::null_mut();

/// Error returned from `TimerHandle::set_fallback`.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -247,23 +247,23 @@ impl TimerHandle {
/// successful then no future calls may succeed.
fn set_as_global_fallback(self) -> Result<(), SetDefaultError> {
unsafe {
let val = self.into_usize();
match HANDLE_FALLBACK.compare_exchange(0, val, SeqCst, SeqCst) {
let val = self.into_raw();
match HANDLE_FALLBACK.compare_exchange(EMPTY_HANDLE, val, SeqCst, SeqCst) {
Ok(_) => Ok(()),
Err(_) => {
drop(TimerHandle::from_usize(val));
drop(TimerHandle::from_raw(val));
Err(SetDefaultError(()))
}
}
}
}

fn into_usize(self) -> usize {
unsafe { mem::transmute::<Weak<Inner>, usize>(self.inner) }
fn into_raw(self) -> *mut Inner {
self.inner.into_raw() as *mut Inner
}

unsafe fn from_usize(val: usize) -> TimerHandle {
let inner = mem::transmute::<usize, Weak<Inner>>(val);
unsafe fn from_raw(val: *mut Inner) -> TimerHandle {
let inner = Weak::from_raw(val);
TimerHandle { inner }
}
}
Expand All @@ -277,7 +277,7 @@ impl Default for TimerHandle {
// actually create a helper thread then we'll just return a "defunkt"
// handle which will return errors when timer objects are attempted to
// be associated.
if fallback == 0 {
if fallback == EMPTY_HANDLE {
let helper = match global::HelperThread::new() {
Ok(helper) => helper,
Err(_) => return TimerHandle { inner: Weak::new() },
Expand All @@ -301,11 +301,11 @@ impl Default for TimerHandle {
// At this point our fallback handle global was configured so we use
// its value to reify a handle, clone it, and then forget our reified
// handle as we don't actually have an owning reference to it.
assert!(fallback != 0);
assert!(fallback != EMPTY_HANDLE);
unsafe {
let handle = TimerHandle::from_usize(fallback);
let handle = TimerHandle::from_raw(fallback);
let ret = handle.clone();
let _ = handle.into_usize();
let _ = handle.into_raw();
ret
}
}
Expand Down

0 comments on commit cb6f0de

Please sign in to comment.