Skip to content

Commit

Permalink
Fix race between closing an observable and a subscriber polling
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Jul 11, 2024
1 parent ac43e87 commit 9517fa4
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 48 deletions.
57 changes: 41 additions & 16 deletions eyeball/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
use std::{
hash::{Hash, Hasher},
mem,
sync::{
atomic::{AtomicU64, Ordering},
RwLock,
},
task::Waker,
sync::RwLock,
task::{Context, Poll, Waker},
};

#[derive(Debug)]
pub struct ObservableState<T> {
/// The inner value.
/// The wrapped value.
value: T,

/// The attached observable metadata.
metadata: RwLock<ObservableStateMetadata>,
}

#[derive(Debug)]
struct ObservableStateMetadata {
/// The version of the value.
///
/// Starts at 1 and is incremented by 1 each time the value is updated.
/// When the observable is dropped, this is set to 0 to indicate no further
/// updates will happen.
version: AtomicU64,
version: u64,

/// List of wakers.
///
Expand All @@ -27,12 +30,18 @@ pub struct ObservableState<T> {
/// locked for reading. This way, it is guaranteed that between a subscriber
/// reading the value and adding a waker because the value hasn't changed
/// yet, no updates to the value could have happened.
wakers: RwLock<Vec<Waker>>,
wakers: Vec<Waker>,
}

impl Default for ObservableStateMetadata {
fn default() -> Self {
Self { version: 1, wakers: Vec::new() }
}
}

impl<T> ObservableState<T> {
pub(crate) fn new(value: T) -> Self {
Self { value, version: AtomicU64::new(1), wakers: Default::default() }
Self { value, metadata: Default::default() }
}

/// Get a reference to the inner value.
Expand All @@ -42,11 +51,25 @@ impl<T> ObservableState<T> {

/// Get the current version of the inner value.
pub(crate) fn version(&self) -> u64 {
self.version.load(Ordering::Acquire)
self.metadata.read().unwrap().version
}

pub(crate) fn add_waker(&self, waker: Waker) {
self.wakers.write().unwrap().push(waker);
pub(crate) fn poll_update(
&self,
observed_version: &mut u64,
cx: &Context<'_>,
) -> Poll<Option<()>> {
let mut metadata = self.metadata.write().unwrap();

if metadata.version == 0 {
Poll::Ready(None)
} else if *observed_version < metadata.version {
*observed_version = metadata.version;
Poll::Ready(Some(()))
} else {
metadata.wakers.push(cx.waker().clone());
Poll::Pending
}
}

pub(crate) fn set(&mut self, value: T) -> T {
Expand Down Expand Up @@ -90,14 +113,16 @@ impl<T> ObservableState<T> {

/// "Close" the state – indicate that no further updates will happen.
pub(crate) fn close(&self) {
self.version.store(0, Ordering::Release);
let mut metadata = self.metadata.write().unwrap();
metadata.version = 0;
// Clear the backing buffer for the wakers, no new ones will be added.
wake(mem::take(&mut *self.wakers.write().unwrap()));
wake(mem::take(&mut metadata.wakers));
}

fn incr_version_and_wake(&mut self) {
self.version.fetch_add(1, Ordering::Release);
wake(self.wakers.get_mut().unwrap().drain(..));
let metadata = self.metadata.get_mut().unwrap();
metadata.version += 1;
wake(metadata.wakers.drain(..));
}
}

Expand Down
13 changes: 3 additions & 10 deletions eyeball/src/subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,9 @@ impl<T> Subscriber<T> {

fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll<Option<ObservableReadGuard<'_, T>>> {
let state = self.state.lock();
let version = state.version();
if version == 0 {
Poll::Ready(None)
} else if self.observed_version < version {
self.observed_version = version;
Poll::Ready(Some(ObservableReadGuard::new(state)))
} else {
state.add_waker(cx.waker().clone());
Poll::Pending
}
state
.poll_update(&mut self.observed_version, cx)
.map(|ready| ready.map(|_| ObservableReadGuard::new(state)))
}
}

Expand Down
26 changes: 4 additions & 22 deletions eyeball/src/subscriber/async_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,7 @@ impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
fn poll_update(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
let state = ready!(self.state.get_lock.poll(cx));
self.state.get_lock.set(self.state.inner.clone().lock_owned());

let version = state.version();
if version == 0 {
Poll::Ready(None)
} else if self.observed_version < version {
self.observed_version = version;
Poll::Ready(Some(()))
} else {
state.add_waker(cx.waker().clone());
Poll::Pending
}
state.poll_update(&mut self.observed_version, cx)
}

fn poll_next_nopin(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>>
Expand All @@ -153,17 +143,9 @@ impl<T: Send + Sync + 'static> Subscriber<T, AsyncLock> {
{
let state = ready!(self.state.get_lock.poll(cx));
self.state.get_lock.set(self.state.inner.clone().lock_owned());

let version = state.version();
if version == 0 {
Poll::Ready(None)
} else if self.observed_version < version {
self.observed_version = version;
Poll::Ready(Some(state.get().clone()))
} else {
state.add_waker(cx.waker().clone());
Poll::Pending
}
state
.poll_update(&mut self.observed_version, cx)
.map(|ready| ready.map(|_| state.get().clone()))
}
}

Expand Down

0 comments on commit 9517fa4

Please sign in to comment.