From 7f8442c36d76a83c1591550a9b20e1258a9f2c83 Mon Sep 17 00:00:00 2001 From: Oleg Nosov Date: Fri, 10 Mar 2023 15:10:45 +0100 Subject: [PATCH] `TryFlattenUnordered`: propagate base stream error (#2607) --- futures-util/benches/flatten_unordered.rs | 15 +-- .../src/stream/stream/flatten_unordered.rs | 80 ++++++++++--- futures-util/src/stream/stream/mod.rs | 2 +- futures-util/src/stream/try_stream/mod.rs | 3 +- .../try_stream/try_flatten_unordered.rs | 111 ++++++++++++------ futures/tests/stream_try_stream.rs | 65 +++++++++- 6 files changed, 209 insertions(+), 67 deletions(-) diff --git a/futures-util/benches/flatten_unordered.rs b/futures-util/benches/flatten_unordered.rs index b92f614914..517b2816c3 100644 --- a/futures-util/benches/flatten_unordered.rs +++ b/futures-util/benches/flatten_unordered.rs @@ -8,6 +8,7 @@ use futures::executor::block_on; use futures::future; use futures::stream::{self, StreamExt}; use futures::task::Poll; +use futures_util::FutureExt; use std::collections::VecDeque; use std::thread; @@ -34,17 +35,9 @@ fn oneshot_streams(b: &mut Bencher) { } }); - let mut flatten = stream::unfold(rxs.into_iter(), |mut vals| { - Box::pin(async { - if let Some(next) = vals.next() { - let val = next.await.unwrap(); - Some((val, vals)) - } else { - None - } - }) - }) - .flatten_unordered(None); + let mut flatten = stream::iter(rxs) + .map(|recv| recv.into_stream().map(|val| val.unwrap()).flatten()) + .flatten_unordered(None); block_on(future::poll_fn(move |cx| { let mut count = 0; diff --git a/futures-util/src/stream/stream/flatten_unordered.rs b/futures-util/src/stream/stream/flatten_unordered.rs index 88006cf235..484c3733aa 100644 --- a/futures-util/src/stream/stream/flatten_unordered.rs +++ b/futures-util/src/stream/stream/flatten_unordered.rs @@ -3,6 +3,7 @@ use core::{ cell::UnsafeCell, convert::identity, fmt, + marker::PhantomData, num::NonZeroUsize, pin::Pin, sync::atomic::{AtomicU8, Ordering}, @@ -22,6 +23,10 @@ use futures_task::{waker, ArcWake}; use crate::stream::FuturesUnordered; +/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered) +/// method. +pub type FlattenUnordered = FlattenUnorderedWithFlowController; + /// There is nothing to poll and stream isn't being polled/waking/woken at the moment. const NONE: u8 = 0; @@ -154,7 +159,7 @@ impl SharedPollState { /// Resets current state allowing to poll the stream and wake up wakers. fn reset(&self) -> u8 { - self.state.swap(NEED_TO_POLL_ALL, Ordering::AcqRel) + self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst) } } @@ -276,10 +281,10 @@ impl Future for PollStreamFut { pin_project! { /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered) - /// method. - #[project = FlattenUnorderedProj] + /// method with ability to specify flow controller. + #[project = FlattenUnorderedWithFlowControllerProj] #[must_use = "streams do nothing unless polled"] - pub struct FlattenUnordered where St: Stream { + pub struct FlattenUnorderedWithFlowController where St: Stream { #[pin] inner_streams: FuturesUnordered>, #[pin] @@ -289,34 +294,40 @@ pin_project! { is_stream_done: bool, inner_streams_waker: Arc, stream_waker: Arc, + flow_controller: PhantomData } } -impl fmt::Debug for FlattenUnordered +impl fmt::Debug for FlattenUnorderedWithFlowController where St: Stream + fmt::Debug, St::Item: Stream + fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FlattenUnordered") + f.debug_struct("FlattenUnorderedWithFlowController") .field("poll_state", &self.poll_state) .field("inner_streams", &self.inner_streams) .field("limit", &self.limit) .field("stream", &self.stream) .field("is_stream_done", &self.is_stream_done) + .field("flow_controller", &self.flow_controller) .finish() } } -impl FlattenUnordered +impl FlattenUnorderedWithFlowController where St: Stream, + Fc: FlowController::Item>, St::Item: Stream + Unpin, { - pub(super) fn new(stream: St, limit: Option) -> FlattenUnordered { + pub(crate) fn new( + stream: St, + limit: Option, + ) -> FlattenUnorderedWithFlowController { let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM); - FlattenUnordered { + FlattenUnorderedWithFlowController { inner_streams: FuturesUnordered::new(), stream, is_stream_done: false, @@ -332,13 +343,35 @@ where need_to_poll: NEED_TO_POLL_STREAM, }), poll_state, + flow_controller: PhantomData, } } delegate_access_inner!(stream, St, ()); } -impl FlattenUnorderedProj<'_, St> +/// Returns the next flow step based on the received item. +pub trait FlowController { + /// Handles an item producing `FlowStep` describing the next flow step. + fn next_step(item: I) -> FlowStep; +} + +impl FlowController for () { + fn next_step(item: I) -> FlowStep { + FlowStep::Continue(item) + } +} + +/// Describes the next flow step. +#[derive(Debug, Clone)] +pub enum FlowStep { + /// Just yields an item and continues standard flow. + Continue(C), + /// Immediately returns an underlying item from the function. + Return(R), +} + +impl FlattenUnorderedWithFlowControllerProj<'_, St, Fc> where St: Stream, { @@ -348,9 +381,10 @@ where } } -impl FusedStream for FlattenUnordered +impl FusedStream for FlattenUnorderedWithFlowController where St: FusedStream, + Fc: FlowController::Item>, St::Item: Stream + Unpin, { fn is_terminated(&self) -> bool { @@ -358,9 +392,10 @@ where } } -impl Stream for FlattenUnordered +impl Stream for FlattenUnorderedWithFlowController where St: Stream, + Fc: FlowController::Item>, St::Item: Stream + Unpin, { type Item = ::Item; @@ -405,8 +440,23 @@ where let mut cx = Context::from_waker(stream_waker.as_ref().unwrap()); match this.stream.as_mut().poll_next(&mut cx) { - Poll::Ready(Some(inner_stream)) => { - let next_item_fut = PollStreamFut::new(inner_stream); + Poll::Ready(Some(item)) => { + let next_item_fut = match Fc::next_step(item) { + // Propagates an item immediately (the main use-case is for errors) + FlowStep::Return(item) => { + need_to_poll_next |= NEED_TO_POLL_STREAM + | (poll_state_value & NEED_TO_POLL_INNER_STREAMS); + poll_state_value &= !NEED_TO_POLL_INNER_STREAMS; + + next_item = Some(item); + + break; + } + // Yields an item and continues processing (normal case) + FlowStep::Continue(inner_stream) => { + PollStreamFut::new(inner_stream) + } + }; // Add new stream to the inner streams bucket this.inner_streams.as_mut().push(next_item_fut); // Inner streams must be polled afterward @@ -478,7 +528,7 @@ where // Forwarding impl of Sink from the underlying stream #[cfg(feature = "sink")] -impl Sink for FlattenUnordered +impl Sink for FlattenUnorderedWithFlowController where St: Stream + Sink, { diff --git a/futures-util/src/stream/stream/mod.rs b/futures-util/src/stream/stream/mod.rs index eb86cb757d..558dc22bd7 100644 --- a/futures-util/src/stream/stream/mod.rs +++ b/futures-util/src/stream/stream/mod.rs @@ -199,7 +199,7 @@ pub use self::buffered::Buffered; #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] -mod flatten_unordered; +pub(crate) mod flatten_unordered; #[cfg(not(futures_no_atomic_cas))] #[cfg(feature = "alloc")] diff --git a/futures-util/src/stream/try_stream/mod.rs b/futures-util/src/stream/try_stream/mod.rs index 42f5e7324b..414a40dbe3 100644 --- a/futures-util/src/stream/try_stream/mod.rs +++ b/futures-util/src/stream/try_stream/mod.rs @@ -721,7 +721,8 @@ pub trait TryStreamExt: TryStream { } /// Flattens a stream of streams into just one continuous stream. Produced streams - /// will be polled concurrently and any errors are passed through without looking at them. + /// will be polled concurrently and any errors will be passed through without looking at them. + /// If the underlying base stream returns an error, it will be **immediately** propagated. /// /// The only argument is an optional limit on the number of concurrently /// polled streams. If this limit is not `None`, no more than `limit` streams diff --git a/futures-util/src/stream/try_stream/try_flatten_unordered.rs b/futures-util/src/stream/try_stream/try_flatten_unordered.rs index e21b514023..a74dfc451d 100644 --- a/futures-util/src/stream/try_stream/try_flatten_unordered.rs +++ b/futures-util/src/stream/try_stream/try_flatten_unordered.rs @@ -1,3 +1,4 @@ +use core::marker::PhantomData; use core::pin::Pin; use futures_core::ready; @@ -9,19 +10,23 @@ use futures_sink::Sink; use pin_project_lite::pin_project; use crate::future::Either; -use crate::stream::stream::FlattenUnordered; -use crate::StreamExt; - -use super::IntoStream; +use crate::stream::stream::flatten_unordered::{ + FlattenUnorderedWithFlowController, FlowController, FlowStep, +}; +use crate::stream::IntoStream; +use crate::TryStreamExt; delegate_all!( /// Stream for the [`try_flatten_unordered`](super::TryStreamExt::try_flatten_unordered) method. TryFlattenUnordered( - FlattenUnordered> + FlattenUnorderedWithFlowController, PropagateBaseStreamError> ): Debug + Sink + Stream + FusedStream + AccessInner[St, (. .)] + New[ |stream: St, limit: impl Into>| - TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams::new(stream).flatten_unordered(limit) + FlattenUnorderedWithFlowController::new( + NestedTryStreamIntoEitherTryStream::new(stream), + limit.into() + ) ] where St: TryStream, @@ -35,7 +40,7 @@ pin_project! { /// This's a wrapper for `FlattenUnordered` to reuse its logic over `TryStream`. #[derive(Debug)] #[must_use = "streams do nothing unless polled"] - pub struct TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams + pub struct NestedTryStreamIntoEitherTryStream where St: TryStream, St::Ok: TryStream, @@ -43,11 +48,11 @@ pin_project! { ::Error: From { #[pin] - stream: St, + stream: St } } -impl TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +impl NestedTryStreamIntoEitherTryStream where St: TryStream, St::Ok: TryStream + Unpin, @@ -60,21 +65,22 @@ where delegate_access_inner!(stream, St, ()); } -impl FusedStream for TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams -where - St: TryStream + FusedStream, - St::Ok: TryStream + Unpin, - ::Error: From, -{ - fn is_terminated(&self) -> bool { - self.stream.is_terminated() - } -} - -/// Emits single item immediately, then stream will be terminated. +/// Emits a single item immediately, then stream will be terminated. #[derive(Debug, Clone)] pub struct Single(Option); +impl Single { + /// Constructs new `Single` with the given value. + fn new(val: T) -> Self { + Self(Some(val)) + } + + /// Attempts to take inner item immediately. Will always succeed if the stream isn't terminated. + fn next_immediate(&mut self) -> Option { + self.0.take() + } +} + impl Unpin for Single {} impl Stream for Single { @@ -89,9 +95,32 @@ impl Stream for Single { } } +/// Immediately propagates errors occurred in the base stream. +#[derive(Debug, Clone, Copy)] +pub struct PropagateBaseStreamError(PhantomData); + +type BaseStreamItem = as Stream>::Item; +type InnerStreamItem = as Stream>::Item; + +impl FlowController, InnerStreamItem> for PropagateBaseStreamError +where + St: TryStream, + St::Ok: TryStream + Unpin, + ::Error: From, +{ + fn next_step(item: BaseStreamItem) -> FlowStep, InnerStreamItem> { + match item { + // A new successful inner stream received + st @ Either::Left(_) => FlowStep::Continue(st), + // An error encountered + Either::Right(mut err) => FlowStep::Return(err.next_immediate().unwrap()), + } + } +} + type SingleStreamResult = Single::Ok, ::Error>>; -impl Stream for TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +impl Stream for NestedTryStreamIntoEitherTryStream where St: TryStream, St::Ok: TryStream + Unpin, @@ -104,24 +133,38 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let item = ready!(self.project().stream.try_poll_next(cx)); - let out = item.map(|res| match res { - // Emit successful inner stream as is - Ok(stream) => Either::Left(IntoStream::new(stream)), - // Wrap an error into a stream containing a single item - err @ Err(_) => { - let res = err.map(|_: St::Ok| unreachable!()).map_err(Into::into); - - Either::Right(Single(Some(res))) - } - }); + let out = match item { + Some(res) => match res { + // Emit successful inner stream as is + Ok(stream) => Either::Left(stream.into_stream()), + // Wrap an error into a stream containing a single item + err @ Err(_) => { + let res = err.map(|_: St::Ok| unreachable!()).map_err(Into::into); + + Either::Right(Single::new(res)) + } + }, + None => return Poll::Ready(None), + }; + + Poll::Ready(Some(out)) + } +} - Poll::Ready(out) +impl FusedStream for NestedTryStreamIntoEitherTryStream +where + St: TryStream + FusedStream, + St::Ok: TryStream + Unpin, + ::Error: From, +{ + fn is_terminated(&self) -> bool { + self.stream.is_terminated() } } // Forwarding impl of Sink from the underlying stream #[cfg(feature = "sink")] -impl Sink for TryStreamOfTryStreamsIntoHomogeneousStreamOfTryStreams +impl Sink for NestedTryStreamIntoEitherTryStream where St: TryStream + Sink, St::Ok: TryStream + Unpin, diff --git a/futures/tests/stream_try_stream.rs b/futures/tests/stream_try_stream.rs index 6d00097970..b3d04b9200 100644 --- a/futures/tests/stream_try_stream.rs +++ b/futures/tests/stream_try_stream.rs @@ -1,8 +1,12 @@ +use core::pin::Pin; + use futures::{ - stream::{self, StreamExt, TryStreamExt}, + stream::{self, repeat, Repeat, StreamExt, TryStreamExt}, task::Poll, + Stream, }; use futures_executor::block_on; +use futures_task::Context; use futures_test::task::noop_context; #[test] @@ -40,7 +44,7 @@ fn try_take_while_after_err() { #[test] fn try_flatten_unordered() { - let s = stream::iter(1..7) + let test_st = stream::iter(1..7) .map(|val: u32| { if val % 2 == 0 { Ok(stream::unfold((val, 1), |(val, pow)| async move { @@ -61,10 +65,10 @@ fn try_flatten_unordered() { // For all basic evens we must have powers from 1 to 3 vec![ Err(1), - Ok(2), Err(3), - Ok(4), Err(5), + Ok(2), + Ok(4), Ok(6), Ok(4), Err(16), @@ -73,7 +77,58 @@ fn try_flatten_unordered() { Err(64), Ok(216) ], - s.collect::>().await + test_st.collect::>().await ) + }); + + #[derive(Clone, Debug)] + struct ErrorStream { + error_after: usize, + polled: usize, + } + + impl Stream for ErrorStream { + type Item = Result>, ()>; + + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context) -> Poll> { + if self.polled > self.error_after { + panic!("Polled after error"); + } else { + let out = + if self.polled == self.error_after { Err(()) } else { Ok(repeat(Ok(()))) }; + self.polled += 1; + Poll::Ready(Some(out)) + } + } + } + + block_on(async move { + let mut st = ErrorStream { error_after: 3, polled: 0 }.try_flatten_unordered(None); + let mut ctr = 0; + while (st.try_next().await).is_ok() { + ctr += 1; + } + assert_eq!(ctr, 0); + + assert_eq!( + ErrorStream { error_after: 10, polled: 0 } + .try_flatten_unordered(None) + .inspect_ok(|_| panic!("Unexpected `Ok`")) + .try_collect::>() + .await, + Err(()) + ); + + let mut taken = 0; + assert_eq!( + ErrorStream { error_after: 10, polled: 0 } + .map_ok(|st| st.take(3)) + .try_flatten_unordered(1) + .inspect(|_| taken += 1) + .try_fold((), |(), res| async move { Ok(res) }) + .await, + Err(()) + ); + assert_eq!(taken, 31); }) }