From 44880d136d35f2dd0a6f1049fd31363bbf334599 Mon Sep 17 00:00:00 2001 From: Leonardo Yvens Date: Tue, 29 Mar 2022 20:25:51 +0100 Subject: [PATCH] util: Fix `call_all` hang when stream is pending (#656) Currently `call_all` will hang in a busy loop if called when the input stream is pending. --- tower/src/util/call_all/common.rs | 26 ++++++++++++-------------- tower/tests/util/call_all.rs | 21 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/tower/src/util/call_all/common.rs b/tower/src/util/call_all/common.rs index 09c458a9b..8e37afc16 100644 --- a/tower/src/util/call_all/common.rs +++ b/tower/src/util/call_all/common.rs @@ -99,20 +99,18 @@ where .expect("Using CallAll after extracing inner Service"); ready!(svc.poll_ready(cx)).map_err(Into::into)?; - // If it is, gather the next request (if there is one) - match this.stream.as_mut().poll_next(cx) { - Poll::Ready(r) => match r { - Some(req) => { - this.queue.push(svc.call(req)); - } - None => { - // We're all done once any outstanding requests have completed - *this.eof = true; - } - }, - Poll::Pending => { - // TODO: We probably want to "release" the slot we reserved in Svc here. - // It may be a while until we get around to actually using it. + // If it is, gather the next request (if there is one), or return `Pending` if the + // stream is not ready. + // TODO: We probably want to "release" the slot we reserved in Svc if the + // stream returns `Pending`. It may be a while until we get around to actually + // using it. + match ready!(this.stream.as_mut().poll_next(cx)) { + Some(req) => { + this.queue.push(svc.call(req)); + } + None => { + // We're all done once any outstanding requests have completed + *this.eof = true; } } } diff --git a/tower/tests/util/call_all.rs b/tower/tests/util/call_all.rs index 6bc092918..02e69ef3d 100644 --- a/tower/tests/util/call_all.rs +++ b/tower/tests/util/call_all.rs @@ -143,3 +143,24 @@ async fn unordered() { .unwrap(); assert!(v.is_none()); } + +#[tokio::test] +async fn pending() { + let _t = support::trace_init(); + + let (mock, mut handle) = mock::pair::<_, &'static str>(); + + let mut task = task::spawn(()); + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let ca = mock.call_all(support::IntoStream::new(rx)); + pin_mut!(ca); + + assert_pending!(task.enter(|cx, _| ca.as_mut().poll_next(cx))); + tx.send("req").unwrap(); + assert_pending!(task.enter(|cx, _| ca.as_mut().poll_next(cx))); + assert_request_eq!(handle, "req").send_response("res"); + let res = assert_ready!(task.enter(|cx, _| ca.as_mut().poll_next(cx))); + assert_eq!(res.transpose().unwrap(), Some("res")); + assert_pending!(task.enter(|cx, _| ca.as_mut().poll_next(cx))); +}