Skip to content

Commit

Permalink
sync::watch: Use Acquire/Release memory ordering instead of SeqCst
Browse files Browse the repository at this point in the history
  • Loading branch information
uklotzde committed Sep 19, 2023
1 parent ad7f988 commit 9f1723b
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
use crate::sync::notify::Notify;

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::Relaxed;
use crate::loom::sync::atomic::Ordering;
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
use std::fmt;
use std::mem;
Expand Down Expand Up @@ -247,7 +247,8 @@ struct Shared<T> {

impl<T: fmt::Debug> fmt::Debug for Shared<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let state = self.state.load();
// Using `Relaxed` ordering is sufficient for this purpose.
let state = self.state.load(Ordering::Relaxed);
f.debug_struct("Shared")
.field("value", &self.value)
.field("version", &state.version())
Expand Down Expand Up @@ -341,7 +342,7 @@ mod big_notify {
/// This function implements the case where randomness is not available.
#[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))]
pub(super) fn notified(&self) -> Notified<'_> {
let i = self.next.fetch_add(1, Relaxed) % 8;
let i = self.next.fetch_add(1, Ordering::Relaxed) % 8;
self.inner[i].notified()
}

Expand All @@ -357,7 +358,7 @@ mod big_notify {
use self::state::{AtomicState, Version};
mod state {
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::SeqCst;
use crate::loom::sync::atomic::Ordering;

const CLOSED_BIT: usize = 1;

Expand All @@ -377,6 +378,11 @@ mod state {
pub(super) struct StateSnapshot(usize);

/// The state stored in an atomic integer.
///
/// The `Sender` uses `Release` ordering for storing a new state
/// and the `Receiver`s use `Acquire` ordering for loading the
/// current state. This ensures that written values are seen by
/// the `Receiver`s for a proper handover.
#[derive(Debug)]
pub(super) struct AtomicState(AtomicUsize);

Expand Down Expand Up @@ -412,18 +418,32 @@ mod state {
}

/// Load the current value of the state.
pub(super) fn load(&self) -> StateSnapshot {
StateSnapshot(self.0.load(SeqCst))
pub(super) fn load(&self, ordering: Ordering) -> StateSnapshot {
StateSnapshot(self.0.load(ordering))
}

/// Load the current value of the state.
///
/// The receiver side (read-only) uses `Acquire` ordering for a proper handover
/// with the sender side (single writer).
pub(super) fn load_receiver(&self) -> StateSnapshot {
StateSnapshot(self.0.load(Ordering::Acquire))
}

/// Increment the version counter.
pub(super) fn increment_version(&self) {
self.0.fetch_add(STEP_SIZE, SeqCst);
// Use `Release` ordering to ensure that storing the version
// state is seen by the receiver side that uses `Acquire` for
// loading the state.
self.0.fetch_add(STEP_SIZE, Ordering::Release);
}

/// Set the closed bit in the state.
pub(super) fn set_closed(&self) {
self.0.fetch_or(CLOSED_BIT, SeqCst);
// Use `Release` ordering to ensure that storing the version
// state is seen by the receiver side that uses `Acquire` for
// loading the state.
self.0.fetch_or(CLOSED_BIT, Ordering::Release);
}
}
}
Expand Down Expand Up @@ -489,7 +509,7 @@ impl<T> Receiver<T> {
fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self {
// No synchronization necessary as this is only used as a counter and
// not memory access.
shared.ref_count_rx.fetch_add(1, Relaxed);
shared.ref_count_rx.fetch_add(1, Ordering::Relaxed);

Self { shared, version }
}
Expand Down Expand Up @@ -543,7 +563,7 @@ impl<T> Receiver<T> {

// After obtaining a read-lock no concurrent writes could occur
// and the loaded version matches that of the borrowed reference.
let new_version = self.shared.state.load().version();
let new_version = self.shared.state.load_receiver().version();
let has_changed = self.version != new_version;

Ref { inner, has_changed }
Expand Down Expand Up @@ -590,7 +610,7 @@ impl<T> Receiver<T> {

// After obtaining a read-lock no concurrent writes could occur
// and the loaded version matches that of the borrowed reference.
let new_version = self.shared.state.load().version();
let new_version = self.shared.state.load_receiver().version();
let has_changed = self.version != new_version;

// Mark the shared value as seen by updating the version
Expand Down Expand Up @@ -631,7 +651,7 @@ impl<T> Receiver<T> {
/// ```
pub fn has_changed(&self) -> Result<bool, error::RecvError> {
// Load the version from the state
let state = self.shared.state.load();
let state = self.shared.state.load_receiver();
if state.is_closed() {
// The sender has dropped.
return Err(error::RecvError(()));
Expand Down Expand Up @@ -768,7 +788,7 @@ impl<T> Receiver<T> {
{
let inner = self.shared.value.read().unwrap();

let new_version = self.shared.state.load().version();
let new_version = self.shared.state.load_receiver().version();
let has_changed = self.version != new_version;
self.version = new_version;

Expand Down Expand Up @@ -814,7 +834,7 @@ fn maybe_changed<T>(
version: &mut Version,
) -> Option<Result<(), error::RecvError>> {
// Load the version from the state
let state = shared.state.load();
let state = shared.state.load_receiver();
let new_version = state.version();

if *version != new_version {
Expand Down Expand Up @@ -865,7 +885,7 @@ impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
// No synchronization necessary as this is only used as a counter and
// not memory access.
if 1 == self.shared.ref_count_rx.fetch_sub(1, Relaxed) {
if 1 == self.shared.ref_count_rx.fetch_sub(1, Ordering::Relaxed) {
// This is the last `Receiver` handle, tasks waiting on `Sender::closed()`
self.shared.notify_tx.notify_waiters();
}
Expand Down Expand Up @@ -1228,7 +1248,7 @@ impl<T> Sender<T> {
/// ```
pub fn subscribe(&self) -> Receiver<T> {
let shared = self.shared.clone();
let version = shared.state.load().version();
let version = shared.state.load_receiver().version();

// The CLOSED bit in the state tracks only whether the sender is
// dropped, so we do not need to unset it if this reopens the channel.
Expand All @@ -1254,7 +1274,7 @@ impl<T> Sender<T> {
/// }
/// ```
pub fn receiver_count(&self) -> usize {
self.shared.ref_count_rx.load(Relaxed)
self.shared.ref_count_rx.load(Ordering::Relaxed)
}
}

Expand Down

0 comments on commit 9f1723b

Please sign in to comment.