diff --git a/src/stream/stream/flat_map.rs b/src/stream/stream/flat_map.rs index e07893a94..d0cc73d32 100644 --- a/src/stream/stream/flat_map.rs +++ b/src/stream/stream/flat_map.rs @@ -50,14 +50,15 @@ where let mut this = self.project(); loop { if let Some(inner) = this.inner_stream.as_mut().as_pin_mut() { - if let item @ Some(_) = futures_core::ready!(inner.poll_next(cx)) { - return Poll::Ready(item); + match futures_core::ready!(inner.poll_next(cx)) { + item @ Some(_) => return Poll::Ready(item), + None => this.inner_stream.set(None), } } match futures_core::ready!(this.stream.as_mut().poll_next(cx)) { + inner @ Some(_) => this.inner_stream.set(inner.map(IntoStream::into_stream)), None => return Poll::Ready(None), - Some(inner) => this.inner_stream.set(Some(inner.into_stream())), } } } diff --git a/src/stream/stream/flatten.rs b/src/stream/stream/flatten.rs index f2e275c29..e7a498dc2 100644 --- a/src/stream/stream/flatten.rs +++ b/src/stream/stream/flatten.rs @@ -52,14 +52,15 @@ where let mut this = self.project(); loop { if let Some(inner) = this.inner_stream.as_mut().as_pin_mut() { - if let item @ Some(_) = futures_core::ready!(inner.poll_next(cx)) { - return Poll::Ready(item); + match futures_core::ready!(inner.poll_next(cx)) { + item @ Some(_) => return Poll::Ready(item), + None => this.inner_stream.set(None), } } match futures_core::ready!(this.stream.as_mut().poll_next(cx)) { + inner @ Some(_) => this.inner_stream.set(inner.map(IntoStream::into_stream)), None => return Poll::Ready(None), - Some(inner) => this.inner_stream.set(Some(inner.into_stream())), } } } diff --git a/tests/stream.rs b/tests/stream.rs index 3576cb900..3a192339f 100644 --- a/tests/stream.rs +++ b/tests/stream.rs @@ -1,3 +1,5 @@ +use std::convert::identity; +use std::marker::Unpin; use std::pin::Pin; use std::task::{Context, Poll}; @@ -108,3 +110,74 @@ fn merge_works_with_unfused_streams() { assert_eq!(xs, vec![92, 92]); }); } + +struct S(T); + +impl Stream for S { + type Item = T::Item; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + unsafe { Pin::new_unchecked(&mut self.0) }.poll_next(ctx) + } +} + +struct StrictOnce { + polled: bool, +} + +impl Stream for StrictOnce { + type Item = (); + + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context) -> Poll> { + assert!(!self.polled, "Polled after completion!"); + self.polled = true; + Poll::Ready(None) + } +} + +struct Interchanger { + polled: bool, +} + +impl Stream for Interchanger { + type Item = S + Unpin>>; + + fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll> { + if self.polled { + self.polled = false; + ctx.waker().wake_by_ref(); + Poll::Pending + } else { + self.polled = true; + Poll::Ready(Some(S(Box::new(StrictOnce { polled: false })))) + } + } +} + +#[test] +fn flat_map_doesnt_poll_completed_inner_stream() { + task::block_on(async { + assert_eq!( + Interchanger { polled: false } + .take(2) + .flat_map(identity) + .count() + .await, + 0 + ); + }); +} + +#[test] +fn flatten_doesnt_poll_completed_inner_stream() { + task::block_on(async { + assert_eq!( + Interchanger { polled: false } + .take(2) + .flatten() + .count() + .await, + 0 + ); + }); +}