Skip to content

Commit

Permalink
Merge Arcs
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpedda committed Sep 3, 2023
1 parent ce51733 commit 9e2e719
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 57 deletions.
45 changes: 22 additions & 23 deletions src/platform_impl/web/async/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,16 @@ use std::task::Poll;
pub fn channel<T>() -> (AsyncSender<T>, AsyncReceiver<T>) {
let (sender, receiver) = mpsc::channel();
let sender = Arc::new(Mutex::new(sender));
let waker = Arc::new(AtomicWaker::new());
let closed = Arc::new(AtomicBool::new(false));
let inner = Arc::new(Inner {
closed: AtomicBool::new(false),
waker: AtomicWaker::new(),
});

let sender = AsyncSender {
sender,
closed: closed.clone(),
waker: Arc::clone(&waker),
};
let receiver = AsyncReceiver {
receiver,
closed,
waker,
inner: Arc::clone(&inner),
};
let receiver = AsyncReceiver { receiver, inner };

(sender, receiver)
}
Expand All @@ -31,14 +28,13 @@ pub struct AsyncSender<T> {
// to wrap it in an `Arc` to make it clonable on the main thread without
// having to block.
sender: Arc<Mutex<Sender<T>>>,
closed: Arc<AtomicBool>,
waker: Arc<AtomicWaker>,
inner: Arc<Inner>,
}

impl<T> AsyncSender<T> {
pub fn send(&self, event: T) -> Result<(), SendError<T>> {
self.sender.lock().unwrap().send(event)?;
self.waker.wake();
self.inner.waker.wake();

Ok(())
}
Expand All @@ -47,9 +43,8 @@ impl<T> AsyncSender<T> {
impl<T> Clone for AsyncSender<T> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
waker: self.waker.clone(),
closed: self.closed.clone(),
sender: Arc::clone(&self.sender),
inner: Arc::clone(&self.inner),
}
}
}
Expand All @@ -58,34 +53,33 @@ impl<T> Drop for AsyncSender<T> {
fn drop(&mut self) {
// If it's the last + the one held by the receiver make sure to wake it
// up and tell it that all receiver have dropped.
if Arc::strong_count(&self.closed) == 2 {
self.closed.store(true, Ordering::Relaxed);
self.waker.wake()
if Arc::strong_count(&self.inner) == 2 {
self.inner.closed.store(true, Ordering::Relaxed);
self.inner.waker.wake()
}
}
}

pub struct AsyncReceiver<T> {
receiver: Receiver<T>,
closed: Arc<AtomicBool>,
waker: Arc<AtomicWaker>,
inner: Arc<Inner>,
}

impl<T> AsyncReceiver<T> {
pub async fn next(&self) -> Result<T, RecvError> {
future::poll_fn(|cx| match self.receiver.try_recv() {
Ok(event) => Poll::Ready(Ok(event)),
Err(TryRecvError::Empty) => {
if self.closed.load(Ordering::Relaxed) {
if self.inner.closed.load(Ordering::Relaxed) {
return Poll::Ready(Err(RecvError));
}

self.waker.register(cx.waker());
self.inner.waker.register(cx.waker());

match self.receiver.try_recv() {
Ok(event) => Poll::Ready(Ok(event)),
Err(TryRecvError::Empty) => {
if self.closed.load(Ordering::Relaxed) {
if self.inner.closed.load(Ordering::Relaxed) {
Poll::Ready(Err(RecvError))
} else {
Poll::Pending
Expand All @@ -99,3 +93,8 @@ impl<T> AsyncReceiver<T> {
.await
}
}

struct Inner {
closed: AtomicBool,
waker: AtomicWaker,
}
61 changes: 27 additions & 34 deletions src/platform_impl/web/async/waker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use std::task::Poll;

pub struct Waker<T: 'static> {
wrapper: Wrapper<false, Handler<T>, Sender, usize>,
counter: Arc<AtomicUsize>,
waker: Arc<AtomicWaker>,
inner: Arc<Inner>,
}

struct Handler<T> {
Expand All @@ -17,35 +16,29 @@ struct Handler<T> {
}

#[derive(Clone)]
struct Sender {
counter: Arc<AtomicUsize>,
waker: Arc<AtomicWaker>,
closed: Arc<AtomicBool>,
}
struct Sender(Arc<Inner>);

impl Drop for Sender {
fn drop(&mut self) {
if Arc::strong_count(&self.closed) == 1 {
self.closed.store(true, Ordering::Relaxed);
self.waker.wake();
if Arc::strong_count(&self.0) == 1 {
self.0.closed.store(true, Ordering::Relaxed);
self.0.waker.wake();
}
}
}

impl<T> Waker<T> {
#[track_caller]
pub fn new(value: T, handler: fn(&T, usize)) -> Option<Self> {
let counter = Arc::new(AtomicUsize::new(0));
let waker = Arc::new(AtomicWaker::new());
let closed = Arc::new(AtomicBool::new(false));
let inner = Arc::new(Inner {
counter: AtomicUsize::new(0),
waker: AtomicWaker::new(),
closed: AtomicBool::new(false),
});

let handler = Handler { value, handler };

let sender = Sender {
counter: Arc::clone(&counter),
waker: Arc::clone(&waker),
closed: Arc::clone(&closed),
};
let sender = Sender(Arc::clone(&inner));

let wrapper = Wrapper::new(
handler,
Expand All @@ -55,28 +48,27 @@ impl<T> Waker<T> {
(handler.handler)(&handler.value, count);
},
{
let counter = Arc::clone(&counter);
let waker = Arc::clone(&waker);
let inner = Arc::clone(&inner);

move |handler| async move {
while let Some(count) = future::poll_fn(|cx| {
let count = counter.swap(0, Ordering::Relaxed);
let count = inner.counter.swap(0, Ordering::Relaxed);

if count > 0 {
Poll::Ready(Some(count))
} else {
if closed.load(Ordering::Relaxed) {
if inner.closed.load(Ordering::Relaxed) {
return Poll::Ready(None);
}

waker.register(cx.waker());
inner.waker.register(cx.waker());

let count = counter.swap(0, Ordering::Relaxed);
let count = inner.counter.swap(0, Ordering::Relaxed);

if count > 0 {
Poll::Ready(Some(count))
} else {
if closed.load(Ordering::Relaxed) {
if inner.closed.load(Ordering::Relaxed) {
return Poll::Ready(None);
}

Expand All @@ -94,16 +86,12 @@ impl<T> Waker<T> {
},
sender,
|inner, _| {
inner.counter.fetch_add(1, Ordering::Relaxed);
inner.waker.wake();
inner.0.counter.fetch_add(1, Ordering::Relaxed);
inner.0.waker.wake();
},
)?;

Some(Self {
wrapper,
counter,
waker,
})
Some(Self { wrapper, inner })
}

pub fn wake(&self) {
Expand All @@ -115,8 +103,13 @@ impl<T> Clone for Waker<T> {
fn clone(&self) -> Self {
Self {
wrapper: self.wrapper.clone(),
counter: Arc::clone(&self.counter),
waker: Arc::clone(&self.waker),
inner: Arc::clone(&self.inner),
}
}
}

struct Inner {
counter: AtomicUsize,
waker: AtomicWaker,
closed: AtomicBool,
}

0 comments on commit 9e2e719

Please sign in to comment.