diff --git a/eyeball/src/state.rs b/eyeball/src/state.rs index ae77a4c..1442b09 100644 --- a/eyeball/src/state.rs +++ b/eyeball/src/state.rs @@ -1,24 +1,27 @@ use std::{ hash::{Hash, Hasher}, mem, - sync::{ - atomic::{AtomicU64, Ordering}, - RwLock, - }, - task::Waker, + sync::RwLock, + task::{Context, Poll, Waker}, }; #[derive(Debug)] pub struct ObservableState { - /// The inner value. + /// The wrapped value. value: T, + /// The attached observable metadata. + metadata: RwLock, +} + +#[derive(Debug)] +struct ObservableStateMetadata { /// The version of the value. /// /// Starts at 1 and is incremented by 1 each time the value is updated. /// When the observable is dropped, this is set to 0 to indicate no further /// updates will happen. - version: AtomicU64, + version: u64, /// List of wakers. /// @@ -27,12 +30,18 @@ pub struct ObservableState { /// locked for reading. This way, it is guaranteed that between a subscriber /// reading the value and adding a waker because the value hasn't changed /// yet, no updates to the value could have happened. - wakers: RwLock>, + wakers: Vec, +} + +impl Default for ObservableStateMetadata { + fn default() -> Self { + Self { version: 1, wakers: Vec::new() } + } } impl ObservableState { pub(crate) fn new(value: T) -> Self { - Self { value, version: AtomicU64::new(1), wakers: Default::default() } + Self { value, metadata: Default::default() } } /// Get a reference to the inner value. @@ -42,11 +51,25 @@ impl ObservableState { /// Get the current version of the inner value. pub(crate) fn version(&self) -> u64 { - self.version.load(Ordering::Acquire) + self.metadata.read().unwrap().version } - pub(crate) fn add_waker(&self, waker: Waker) { - self.wakers.write().unwrap().push(waker); + pub(crate) fn poll_update( + &self, + observed_version: &mut u64, + cx: &Context<'_>, + ) -> Poll> { + let mut metadata = self.metadata.write().unwrap(); + + if metadata.version == 0 { + Poll::Ready(None) + } else if *observed_version < metadata.version { + *observed_version = metadata.version; + Poll::Ready(Some(())) + } else { + metadata.wakers.push(cx.waker().clone()); + Poll::Pending + } } pub(crate) fn set(&mut self, value: T) -> T { @@ -90,14 +113,16 @@ impl ObservableState { /// "Close" the state – indicate that no further updates will happen. pub(crate) fn close(&self) { - self.version.store(0, Ordering::Release); + let mut metadata = self.metadata.write().unwrap(); + metadata.version = 0; // Clear the backing buffer for the wakers, no new ones will be added. - wake(mem::take(&mut *self.wakers.write().unwrap())); + wake(mem::take(&mut metadata.wakers)); } fn incr_version_and_wake(&mut self) { - self.version.fetch_add(1, Ordering::Release); - wake(self.wakers.get_mut().unwrap().drain(..)); + let metadata = self.metadata.get_mut().unwrap(); + metadata.version += 1; + wake(metadata.wakers.drain(..)); } } diff --git a/eyeball/src/subscriber.rs b/eyeball/src/subscriber.rs index ae1bde1..ff2637b 100644 --- a/eyeball/src/subscriber.rs +++ b/eyeball/src/subscriber.rs @@ -122,16 +122,9 @@ impl Subscriber { fn poll_next_ref(&mut self, cx: &Context<'_>) -> Poll>> { let state = self.state.lock(); - let version = state.version(); - if version == 0 { - Poll::Ready(None) - } else if self.observed_version < version { - self.observed_version = version; - Poll::Ready(Some(ObservableReadGuard::new(state))) - } else { - state.add_waker(cx.waker().clone()); - Poll::Pending - } + state + .poll_update(&mut self.observed_version, cx) + .map(|ready| ready.map(|_| ObservableReadGuard::new(state))) } } diff --git a/eyeball/src/subscriber/async_lock.rs b/eyeball/src/subscriber/async_lock.rs index 45b9ea5..6ce9bf2 100644 --- a/eyeball/src/subscriber/async_lock.rs +++ b/eyeball/src/subscriber/async_lock.rs @@ -134,17 +134,7 @@ impl Subscriber { fn poll_update(&mut self, cx: &mut Context<'_>) -> Poll> { let state = ready!(self.state.get_lock.poll(cx)); self.state.get_lock.set(self.state.inner.clone().lock_owned()); - - let version = state.version(); - if version == 0 { - Poll::Ready(None) - } else if self.observed_version < version { - self.observed_version = version; - Poll::Ready(Some(())) - } else { - state.add_waker(cx.waker().clone()); - Poll::Pending - } + state.poll_update(&mut self.observed_version, cx) } fn poll_next_nopin(&mut self, cx: &mut Context<'_>) -> Poll> @@ -153,17 +143,9 @@ impl Subscriber { { let state = ready!(self.state.get_lock.poll(cx)); self.state.get_lock.set(self.state.inner.clone().lock_owned()); - - let version = state.version(); - if version == 0 { - Poll::Ready(None) - } else if self.observed_version < version { - self.observed_version = version; - Poll::Ready(Some(state.get().clone())) - } else { - state.add_waker(cx.waker().clone()); - Poll::Pending - } + state + .poll_update(&mut self.observed_version, cx) + .map(|ready| ready.map(|_| state.get().clone())) } }