diff --git a/tokio/src/runtime/time/entry.rs b/tokio/src/runtime/time/entry.rs index 0bd15a74f8b..834077caa3d 100644 --- a/tokio/src/runtime/time/entry.rs +++ b/tokio/src/runtime/time/entry.rs @@ -21,8 +21,9 @@ //! //! Each timer has a state field associated with it. This field contains either //! the current scheduled time, or a special flag value indicating its state. -//! This state can either indicate that the timer is firing (and thus will be fired -//! with an `Ok(())` result soon) or that it has already been fired/deregistered. +//! This state can either indicate that the timer is on the 'pending' queue (and +//! thus will be fired with an `Ok(())` result soon) or that it has already been +//! fired/deregistered. //! //! This single state field allows for code that is firing the timer to //! synchronize with any racing `reset` calls reliably. @@ -48,10 +49,10 @@ //! There is of course a race condition between timer reset and timer //! expiration. If the driver fails to observe the updated expiration time, it //! could trigger expiration of the timer too early. However, because -//! [`mark_firing`][mark_firing] performs a compare-and-swap, it will identify this race and -//! refuse to mark the timer as firing. +//! [`mark_pending`][mark_pending] performs a compare-and-swap, it will identify this race and +//! refuse to mark the timer as pending. //! -//! [mark_firing]: TimerHandle::mark_firing +//! [mark_pending]: TimerHandle::mark_pending use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicU64; @@ -69,9 +70,9 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; type TimerResult = Result<(), crate::time::error::Error>; -pub(super) const STATE_DEREGISTERED: u64 = u64::MAX; -const STATE_FIRING: u64 = STATE_DEREGISTERED - 1; -const STATE_MIN_VALUE: u64 = STATE_FIRING; +const STATE_DEREGISTERED: u64 = u64::MAX; +const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; +const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; /// The largest safe integer to use for ticks. /// /// This value should be updated if any other signal values are added above. @@ -122,6 +123,10 @@ impl StateCell { } } + fn is_pending(&self) -> bool { + self.state.load(Ordering::Relaxed) == STATE_PENDING_FIRE + } + /// Returns the current expiration time, or None if not currently scheduled. fn when(&self) -> Option { let cur_state = self.state.load(Ordering::Relaxed); @@ -157,28 +162,26 @@ impl StateCell { } } - /// Marks this timer firing, if its scheduled time is not after `not_after`. + /// Marks this timer as being moved to the pending list, if its scheduled + /// time is not after `not_after`. /// /// If the timer is scheduled for a time after `not_after`, returns an Err /// containing the current scheduled time. /// /// SAFETY: Must hold the driver lock. - unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { + unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { // Quick initial debug check to see if the timer is already fired. Since // firing the timer can only happen with the driver lock held, we know // we shouldn't be able to "miss" a transition to a fired state, even // with relaxed ordering. let mut cur_state = self.state.load(Ordering::Relaxed); + loop { - // Because its state is STATE_DEREGISTERED, it has been fired. - if cur_state == STATE_DEREGISTERED { - break Err(cur_state); - } // improve the error message for things like // https://github.com/tokio-rs/tokio/issues/3675 assert!( cur_state < STATE_MIN_VALUE, - "mark_firing called when the timer entry is in an invalid state" + "mark_pending called when the timer entry is in an invalid state" ); if cur_state > not_after { @@ -187,7 +190,7 @@ impl StateCell { match self.state.compare_exchange_weak( cur_state, - STATE_FIRING, + STATE_PENDING_FIRE, Ordering::AcqRel, Ordering::Acquire, ) { @@ -334,6 +337,11 @@ pub(crate) struct TimerShared { /// Only accessed under the entry lock. pointers: linked_list::Pointers, + /// The expiration time for which this entry is currently registered. + /// Generally owned by the driver, but is accessed by the entry when not + /// registered. + cached_when: AtomicU64, + /// Current state. This records whether the timer entry is currently under /// the ownership of the driver, and if not, its current state (not /// complete, fired, error, etc). @@ -348,6 +356,7 @@ unsafe impl Sync for TimerShared {} impl std::fmt::Debug for TimerShared { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("TimerShared") + .field("cached_when", &self.cached_when.load(Ordering::Relaxed)) .field("state", &self.state) .finish() } @@ -365,12 +374,40 @@ impl TimerShared { pub(super) fn new(shard_id: u32) -> Self { Self { shard_id, + cached_when: AtomicU64::new(0), pointers: linked_list::Pointers::new(), state: StateCell::default(), _p: PhantomPinned, } } + /// Gets the cached time-of-expiration value. + pub(super) fn cached_when(&self) -> u64 { + // Cached-when is only accessed under the driver lock, so we can use relaxed + self.cached_when.load(Ordering::Relaxed) + } + + /// Gets the true time-of-expiration value, and copies it into the cached + /// time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn sync_when(&self) -> u64 { + let true_when = self.true_when(); + + self.cached_when.store(true_when, Ordering::Relaxed); + + true_when + } + + /// Sets the cached time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + unsafe fn set_cached_when(&self, when: u64) { + self.cached_when.store(when, Ordering::Relaxed); + } + /// Returns the true time-of-expiration value, with relaxed memory ordering. pub(super) fn true_when(&self) -> u64 { self.state.when().expect("Timer already fired") @@ -383,6 +420,7 @@ impl TimerShared { /// in the timer wheel. pub(super) unsafe fn set_expiration(&self, t: u64) { self.state.set_expiration(t); + self.cached_when.store(t, Ordering::Relaxed); } /// Sets the true time-of-expiration only if it is after the current. @@ -552,8 +590,16 @@ impl TimerEntry { } impl TimerHandle { - pub(super) unsafe fn true_when(&self) -> u64 { - unsafe { self.inner.as_ref().true_when() } + pub(super) unsafe fn cached_when(&self) -> u64 { + unsafe { self.inner.as_ref().cached_when() } + } + + pub(super) unsafe fn sync_when(&self) -> u64 { + unsafe { self.inner.as_ref().sync_when() } + } + + pub(super) unsafe fn is_pending(&self) -> bool { + unsafe { self.inner.as_ref().state.is_pending() } } /// Forcibly sets the true and cached expiration times to the given tick. @@ -564,7 +610,7 @@ impl TimerHandle { self.inner.as_ref().set_expiration(tick); } - /// Attempts to mark this entry as firing. If the expiration time is after + /// Attempts to mark this entry as pending. If the expiration time is after /// `not_after`, however, returns an Err with the current expiration time. /// /// If an `Err` is returned, the `cached_when` value will be updated to this @@ -572,8 +618,19 @@ impl TimerHandle { /// /// SAFETY: The caller must ensure that the handle remains valid, the driver /// lock is held, and that the timer is not in any wheel linked lists. - pub(super) unsafe fn mark_firing(&self, not_after: u64) -> Result<(), u64> { - self.inner.as_ref().state.mark_firing(not_after) + /// After returning Ok, the entry must be added to the pending list. + pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> { + match self.inner.as_ref().state.mark_pending(not_after) { + Ok(()) => { + // mark this as being on the pending queue in cached_when + self.inner.as_ref().set_cached_when(u64::MAX); + Ok(()) + } + Err(tick) => { + self.inner.as_ref().set_cached_when(tick); + Err(tick) + } + } } /// Attempts to transition to a terminal state. If the state is already a diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 0e4c995dcd0..c01a5f2b25e 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -8,7 +8,7 @@ mod entry; pub(crate) use entry::TimerEntry; -use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION, STATE_DEREGISTERED}; +use entry::{EntryList, TimerHandle, TimerShared, MAX_SAFE_MILLIS_DURATION}; mod handle; pub(crate) use self::handle::Handle; @@ -324,53 +324,23 @@ impl Handle { now = lock.elapsed(); } - while let Some(expiration) = lock.poll(now) { - lock.set_elapsed(expiration.deadline); - // It is critical for `GuardedLinkedList` safety that the guard node is - // pinned in memory and is not dropped until the guarded list is dropped. - let guard = TimerShared::new(id); - pin!(guard); - let guard_handle = guard.as_ref().get_ref().handle(); - - // * This list will be still guarded by the lock of the Wheel with the specefied id. - // `EntryWaitersList` wrapper makes sure we hold the lock to modify it. - // * This wrapper will empty the list on drop. It is critical for safety - // that we will not leave any list entry with a pointer to the local - // guard node after this function returns / panics. - // Safety: The `TimerShared` inside this `TimerHandle` is pinned in the memory. - let mut list = unsafe { lock.get_waiters_list(&expiration, guard_handle, id, self) }; - - while let Some(entry) = list.pop_back_locked(&mut lock) { - let deadline = expiration.deadline; - // Try to expire the entry; this is cheap (doesn't synchronize) if - // the timer is not expired, and updates cached_when. - match unsafe { entry.mark_firing(deadline) } { - Ok(()) => { - // Entry was expired. - // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. - if let Some(waker) = unsafe { entry.fire(Ok(())) } { - waker_list.push(waker); - - if !waker_list.can_push() { - // Wake a batch of wakers. To avoid deadlock, - // we must do this with the lock temporarily dropped. - drop(lock); - waker_list.wake_all(); - - lock = self.inner.lock_sharded_wheel(id); - } - } - } - Err(state) if state == STATE_DEREGISTERED => {} - Err(state) => { - // Safety: This Entry has not expired. - unsafe { lock.reinsert_entry(entry, deadline, state) }; - } + while let Some(entry) = lock.poll(now) { + debug_assert!(unsafe { entry.is_pending() }); + + // SAFETY: We hold the driver lock, and just removed the entry from any linked lists. + if let Some(waker) = unsafe { entry.fire(Ok(())) } { + waker_list.push(waker); + + if !waker_list.can_push() { + // Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped. + drop(lock); + + waker_list.wake_all(); + + lock = self.inner.lock_sharded_wheel(id); } } - lock.occupied_bit_maintain(&expiration); } - let next_wake_up = lock.poll_at(); drop(lock); diff --git a/tokio/src/runtime/time/wheel/level.rs b/tokio/src/runtime/time/wheel/level.rs index 6539f47b4fa..d31eaf46879 100644 --- a/tokio/src/runtime/time/wheel/level.rs +++ b/tokio/src/runtime/time/wheel/level.rs @@ -20,6 +20,7 @@ pub(crate) struct Level { } /// Indicates when a slot must be processed next. +#[derive(Debug)] pub(crate) struct Expiration { /// The level containing the slot. pub(crate) level: usize, @@ -80,7 +81,7 @@ impl Level { // pseudo-ring buffer, and we rotate around them indefinitely. If we // compute a deadline before now, and it's the top level, it // therefore means we're actually looking at a slot in the future. - debug_assert_eq!(self.level, super::MAX_LEVEL_INDEX); + debug_assert_eq!(self.level, super::NUM_LEVELS - 1); deadline += level_range; } @@ -119,7 +120,7 @@ impl Level { } pub(crate) unsafe fn add_entry(&mut self, item: TimerHandle) { - let slot = slot_for(item.true_when(), self.level); + let slot = slot_for(item.cached_when(), self.level); self.slot[slot].push_front(item); @@ -127,25 +128,22 @@ impl Level { } pub(crate) unsafe fn remove_entry(&mut self, item: NonNull) { - let slot = slot_for(unsafe { item.as_ref().true_when() }, self.level); + let slot = slot_for(unsafe { item.as_ref().cached_when() }, self.level); unsafe { self.slot[slot].remove(item) }; if self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + // Unset the bit self.occupied ^= occupied_bit(slot); } } - pub(super) fn take_slot(&mut self, slot: usize) -> EntryList { - std::mem::take(&mut self.slot[slot]) - } + pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList { + self.occupied &= !occupied_bit(slot); - pub(super) fn occupied_bit_maintain(&mut self, slot: usize) { - if self.slot[slot].is_empty() { - self.occupied &= !occupied_bit(slot); - } else { - self.occupied |= occupied_bit(slot); - } + std::mem::take(&mut self.slot[slot]) } } diff --git a/tokio/src/runtime/time/wheel/mod.rs b/tokio/src/runtime/time/wheel/mod.rs index a4034053134..f2b4228514c 100644 --- a/tokio/src/runtime/time/wheel/mod.rs +++ b/tokio/src/runtime/time/wheel/mod.rs @@ -1,6 +1,5 @@ use crate::runtime::time::{TimerHandle, TimerShared}; use crate::time::error::InsertError; -use crate::util::linked_list::{self, GuardedLinkedList, LinkedList}; mod level; pub(crate) use self::level::Expiration; @@ -8,59 +7,7 @@ use self::level::Level; use std::{array, ptr::NonNull}; -use super::entry::MAX_SAFE_MILLIS_DURATION; -use super::Handle; - -/// List used in `Handle::process_at_sharded_time`. It wraps a guarded linked list -/// and gates the access to it on the lock of the `Wheel` with the specified `wheel_id`. -/// It also empties the list on drop. -pub(super) struct EntryWaitersList<'a> { - // GuardedLinkedList ensures that the concurrent drop of Entry in this slot is safe. - list: GuardedLinkedList::Target>, - is_empty: bool, - wheel_id: u32, - handle: &'a Handle, -} - -impl<'a> Drop for EntryWaitersList<'a> { - fn drop(&mut self) { - // If the list is not empty, we unlink all waiters from it. - // We do not wake the waiters to avoid double panics. - if !self.is_empty { - let _lock = self.handle.inner.lock_sharded_wheel(self.wheel_id); - while self.list.pop_back().is_some() {} - } - } -} - -impl<'a> EntryWaitersList<'a> { - fn new( - unguarded_list: LinkedList::Target>, - guard_handle: TimerHandle, - wheel_id: u32, - handle: &'a Handle, - ) -> Self { - let list = unguarded_list.into_guarded(guard_handle); - Self { - list, - is_empty: false, - wheel_id, - handle, - } - } - - /// Removes the last element from the guarded list. Modifying this list - /// requires an exclusive access to the Wheel with the specified `wheel_id`. - pub(super) fn pop_back_locked(&mut self, _wheel: &mut Wheel) -> Option { - let result = self.list.pop_back(); - if result.is_none() { - // Save information about emptiness to avoid waiting for lock - // in the destructor. - self.is_empty = true; - } - result - } -} +use super::EntryList; /// Timing wheel implementation. /// @@ -89,6 +36,9 @@ pub(crate) struct Wheel { /// * ~ 4 hr slots / ~ 12 day range /// * ~ 12 day slots / ~ 2 yr range levels: Box<[Level; NUM_LEVELS]>, + + /// Entries queued for firing + pending: EntryList, } /// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots @@ -96,9 +46,6 @@ pub(crate) struct Wheel { /// precision of 1 millisecond. const NUM_LEVELS: usize = 6; -/// The max level index. -pub(super) const MAX_LEVEL_INDEX: usize = NUM_LEVELS - 1; - /// The maximum duration of a `Sleep`. pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; @@ -108,6 +55,7 @@ impl Wheel { Wheel { elapsed: 0, levels: Box::new(array::from_fn(Level::new)), + pending: EntryList::new(), } } @@ -142,7 +90,7 @@ impl Wheel { &mut self, item: TimerHandle, ) -> Result { - let when = item.true_when(); + let when = item.sync_when(); if when <= self.elapsed { return Err((item, InsertError::Elapsed)); @@ -151,7 +99,9 @@ impl Wheel { // Get the level at which the entry should be stored let level = self.level_for(when); - unsafe { self.levels[level].add_entry(item) }; + unsafe { + self.levels[level].add_entry(item); + } debug_assert!({ self.levels[level] @@ -166,8 +116,10 @@ impl Wheel { /// Removes `item` from the timing wheel. pub(crate) unsafe fn remove(&mut self, item: NonNull) { unsafe { - let when = item.as_ref().true_when(); - if when <= MAX_SAFE_MILLIS_DURATION { + let when = item.as_ref().cached_when(); + if when == u64::MAX { + self.pending.remove(item); + } else { debug_assert!( self.elapsed <= when, "elapsed={}; when={}", @@ -176,42 +128,54 @@ impl Wheel { ); let level = self.level_for(when); - // If the entry is not contained in the `slot` list, - // then it is contained by a guarded list. self.levels[level].remove_entry(item); } } } - /// Reinserts `item` to the timing wheel. - /// Safety: This entry must not have expired. - pub(super) unsafe fn reinsert_entry(&mut self, entry: TimerHandle, elapsed: u64, when: u64) { - let level = level_for(elapsed, when); - unsafe { self.levels[level].add_entry(entry) }; - } - /// Instant at which to poll. pub(crate) fn poll_at(&self) -> Option { self.next_expiration().map(|expiration| expiration.deadline) } /// Advances the timer up to the instant represented by `now`. - pub(crate) fn poll(&mut self, now: u64) -> Option { - match self.next_expiration() { - Some(expiration) if expiration.deadline <= now => Some(expiration), - _ => { - // in this case the poll did not indicate an expiration - // _and_ we were not able to find a next expiration in - // the current list of timers. advance to the poll's - // current time and do nothing else. - self.set_elapsed(now); - None + pub(crate) fn poll(&mut self, now: u64) -> Option { + loop { + if let Some(handle) = self.pending.pop_back() { + return Some(handle); + } + + match self.next_expiration() { + Some(ref expiration) if expiration.deadline <= now => { + self.process_expiration(expiration); + + self.set_elapsed(expiration.deadline); + } + _ => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + break; + } } } + + self.pending.pop_back() } /// Returns the instant at which the next timeout expires. fn next_expiration(&self) -> Option { + if !self.pending.is_empty() { + // Expire immediately as we have things pending firing + return Some(Expiration { + level: 0, + slot: 0, + deadline: self.elapsed, + }); + } + // Check all levels for (level_num, level) in self.levels.iter().enumerate() { if let Some(expiration) = level.next_expiration(self.elapsed) { @@ -247,28 +211,11 @@ impl Wheel { res } - pub(super) fn set_elapsed(&mut self, when: u64) { - assert!( - self.elapsed <= when, - "elapsed={:?}; when={:?}", - self.elapsed, - when - ); - - if when > self.elapsed { - self.elapsed = when; - } - } - - /// Obtains the guarded list of entries that need processing for the given expiration. - /// Safety: The `TimerShared` inside `guard_handle` must be pinned in the memory. - pub(super) unsafe fn get_waiters_list<'a>( - &mut self, - expiration: &Expiration, - guard_handle: TimerHandle, - wheel_id: u32, - handle: &'a Handle, - ) -> EntryWaitersList<'a> { + /// iteratively find entries that are between the wheel's current + /// time and the expiration time. for each in that population either + /// queue it for notification (in the case of the last level) or tier + /// it down to the next level (in all other cases). + pub(crate) fn process_expiration(&mut self, expiration: &Expiration) { // Note that we need to take _all_ of the entries off the list before // processing any of them. This is important because it's possible that // those entries might need to be reinserted into the same slot. @@ -279,12 +226,46 @@ impl Wheel { // they actually need to be dropped down a level. We then reinsert them // back into the same position; we must make sure we don't then process // those entries again or we'll end up in an infinite loop. - let unguarded_list = self.levels[expiration.level].take_slot(expiration.slot); - EntryWaitersList::new(unguarded_list, guard_handle, wheel_id, handle) + let mut entries = self.take_entries(expiration); + + while let Some(item) = entries.pop_back() { + if expiration.level == 0 { + debug_assert_eq!(unsafe { item.cached_when() }, expiration.deadline); + } + + // Try to expire the entry; this is cheap (doesn't synchronize) if + // the timer is not expired, and updates cached_when. + match unsafe { item.mark_pending(expiration.deadline) } { + Ok(()) => { + // Item was expired + self.pending.push_front(item); + } + Err(expiration_tick) => { + let level = level_for(expiration.deadline, expiration_tick); + unsafe { + self.levels[level].add_entry(item); + } + } + } + } + } + + fn set_elapsed(&mut self, when: u64) { + assert!( + self.elapsed <= when, + "elapsed={:?}; when={:?}", + self.elapsed, + when + ); + + if when > self.elapsed { + self.elapsed = when; + } } - pub(super) fn occupied_bit_maintain(&mut self, expiration: &Expiration) { - self.levels[expiration.level].occupied_bit_maintain(expiration.slot); + /// Obtains the list of entries that need processing for the given expiration. + fn take_entries(&mut self, expiration: &Expiration) -> EntryList { + self.levels[expiration.level].take_slot(expiration.slot) } fn level_for(&self, when: u64) -> usize { diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index 382d9ee1978..3650f87fbb0 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -334,7 +334,6 @@ feature! { feature = "sync", feature = "rt", feature = "signal", - feature = "time", )] /// An intrusive linked list, but instead of keeping pointers to the head