From 7de47b06af354f2d831a84151483166f0ae91423 Mon Sep 17 00:00:00 2001 From: Jon Gjengset Date: Tue, 20 Oct 2020 15:35:39 -0700 Subject: [PATCH 1/2] tokio 0.3 --- Cargo.toml | 4 +-- src/combinator.rs | 70 +++++++++++++++++++++++++++++++++-------------- src/lib.rs | 48 ++++++++++++++++---------------- 3 files changed, 76 insertions(+), 46 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b1d55c1..08cfedf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,8 +24,8 @@ maintenance = { status = "passively-maintained" } futures-core = "0.3.0" futures-util = "0.3.0" pin-project = "0.4.0" -tokio = { version = "0.2.0", features = ["sync", "io-util"] } +tokio = { version = "0.3.0", features = ["sync", "io-util"] } [dev-dependencies] futures = "0.3.0" -tokio = { version = "0.2.0", features = ["full"] } +tokio = { version = "0.3.0", features = ["full"] } diff --git a/src/combinator.rs b/src/combinator.rs index 34bfaba..c35e283 100644 --- a/src/combinator.rs +++ b/src/combinator.rs @@ -1,6 +1,7 @@ use crate::Trigger; use futures_core::{ready, stream::Stream}; use pin_project::pin_project; +use std::fmt; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -43,7 +44,7 @@ pub trait StreamExt: Stream { /// let (tx, rx) = tokio::sync::oneshot::channel(); /// /// tokio::spawn(async move { - /// let mut incoming = listener.incoming().take_until_if(rx.map(|_| true)); + /// let mut incoming = listener.take_until_if(rx.map(|_| true)); /// while let Some(mut s) = incoming.next().await.transpose().unwrap() { /// tokio::spawn(async move { /// let (mut r, mut w) = s.split(); @@ -108,38 +109,67 @@ where /// `Tripwire` is internally implemented using a `Shared>`, with the /// `Trigger` holding the associated `oneshot::Sender`. There is very little magic. #[pin_project] -#[derive(Clone, Debug)] -pub struct Tripwire(#[pin] watch::Receiver); +pub struct Tripwire { + watch: watch::Receiver, + + // TODO: existential type + #[pin] + fut: Option + Send>>>, +} + +impl Clone for Tripwire { + fn clone(&self) -> Self { + Self { + watch: self.watch.clone(), + fut: None, + } + } +} + +impl fmt::Debug for Tripwire { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Tripwise").field(&self.watch).finish() + } +} impl Tripwire { /// Make a new `Tripwire` and an associated [`Trigger`]. pub fn new() -> (Trigger, Self) { let (tx, rx) = watch::channel(false); - (Trigger(Some(tx)), Tripwire(rx)) + ( + Trigger(Some(tx)), + Tripwire { + watch: rx, + fut: None, + }, + ) } } impl Future for Tripwire { type Output = bool; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project().0; - loop { - match this.as_mut().poll_recv_ref(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(None) => { - // channel was closed -- we return whatever the latest value was - } - Poll::Ready(Some(v)) if *v => { - // value change to true, and we should exit - return Poll::Ready(true); + let mut this = self.project(); + if this.fut.is_none() { + let mut watch = this.watch.clone(); + this.fut.set(Some(Box::pin(async move { + while !*watch.borrow() { + // value is currently false; wait for it to change + if let Err(_) = watch.changed().await { + // channel was closed -- we return whatever the latest value was + return *watch.borrow(); + } } - Poll::Ready(Some(_)) => { - // value is currently false, we need to poll again - continue; - } - } - return Poll::Ready(*this.borrow()); + // value change to true, and we should exit + true + }))); } + + // Safety: we never move the value inside the option. + // If the Tripwire is pinned, the Option is pinned, and the future inside is as well. + unsafe { this.fut.map_unchecked_mut(|f| f.as_mut().unwrap()) } + .as_mut() + .poll(cx) } } diff --git a/src/lib.rs b/src/lib.rs index 2a58297..111ee78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,11 +19,11 @@ //! //! #[tokio::main] //! async fn main() { -//! let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); +//! let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); //! let (trigger, tripwire) = Tripwire::new(); //! //! tokio::spawn(async move { -//! let mut incoming = listener.incoming().take_until_if(tripwire); +//! let mut incoming = listener.take_until_if(tripwire); //! while let Some(mut s) = incoming.next().await.transpose().unwrap() { //! tokio::spawn(async move { //! let (mut r, mut w) = s.split(); @@ -54,10 +54,10 @@ //! #[tokio::main] //! async fn main() { //! let (exit_tx, exit_rx) = tokio::sync::oneshot::channel(); -//! let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); +//! let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); //! //! tokio::spawn(async move { -//! let (exit, mut incoming) = Valved::new(listener.incoming()); +//! let (exit, mut incoming) = Valved::new(listener); //! exit_tx.send(exit).unwrap(); //! while let Some(mut s) = incoming.next().await.transpose().unwrap() { //! tokio::spawn(async move { @@ -87,12 +87,12 @@ //! #[tokio::main] //! async fn main() { //! let (exit, valve) = Valve::new(); -//! let mut listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); -//! let mut listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); +//! let listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); +//! let listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); //! //! tokio::spawn(async move { -//! let incoming1 = valve.wrap(listener1.incoming()); -//! let incoming2 = valve.wrap(listener2.incoming()); +//! let incoming1 = valve.wrap(listener1); +//! let incoming2 = valve.wrap(listener2); //! //! use futures_util::stream::select; //! let mut incoming = select(incoming1, incoming2); @@ -147,7 +147,7 @@ impl Drop for Trigger { if let Some(tx) = self.0.take() { // Send may fail when all associated rx'es are dropped already // so code here cannot panic on error - let _ = tx.broadcast(true); + let _ = tx.send(true); } } } @@ -163,8 +163,8 @@ mod tests { fn tokio_run() { use std::thread; - let mut rt = tokio::runtime::Runtime::new().unwrap(); - let mut listener = rt + let rt = tokio::runtime::Runtime::new().unwrap(); + let listener = rt .block_on(tokio::net::TcpListener::bind("0.0.0.0:0")) .unwrap(); let (exit_tx, exit_rx) = tokio::sync::oneshot::channel(); @@ -173,7 +173,7 @@ mod tests { // start a tokio echo server rt.block_on(async move { - let (exit, mut incoming) = Valved::new(listener.incoming()); + let (exit, mut incoming) = Valved::new(listener); exit_tx.send(exit).unwrap(); while let Some(mut s) = incoming.next().await.transpose().unwrap() { tokio::spawn(async move { @@ -200,8 +200,8 @@ mod tests { let (exit_tx, exit_rx) = tokio::sync::oneshot::channel(); tokio::spawn(async move { - let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); - let (exit, mut incoming) = Valved::new(listener.incoming()); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let (exit, mut incoming) = Valved::new(listener); exit_tx.send(exit).unwrap(); while let Some(mut s) = incoming.next().await.transpose().unwrap() { tokio::spawn(async move { @@ -219,10 +219,10 @@ mod tests { async fn multi_interrupt() { let (exit, valve) = Valve::new(); tokio::spawn(async move { - let mut listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); - let mut listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); - let incoming1 = valve.wrap(listener1.incoming()); - let incoming2 = valve.wrap(listener2.incoming()); + let listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let incoming1 = valve.wrap(listener1); + let incoming2 = valve.wrap(listener2); let mut incoming = select(incoming1, incoming2); while let Some(mut s) = incoming.next().await.transpose().unwrap() { @@ -246,13 +246,13 @@ mod tests { }; let (exit, valve) = Valve::new(); - let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let reqs = Arc::new(AtomicUsize::new(0)); let got = reqs.clone(); tokio::spawn(async move { - let mut incoming = valve.wrap(listener.incoming()); + let mut incoming = valve.wrap(listener); while let Some(mut s) = incoming.next().await.transpose().unwrap() { reqs.fetch_add(1, Ordering::SeqCst); tokio::spawn(async move { @@ -289,8 +289,8 @@ mod tests { }; let (exit, valve) = Valve::new(); - let mut listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); - let mut listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); + let listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap(); let addr1 = listener1.local_addr().unwrap(); let addr2 = listener2.local_addr().unwrap(); @@ -298,8 +298,8 @@ mod tests { let got = reqs.clone(); tokio::spawn(async move { - let incoming1 = valve.wrap(listener1.incoming()); - let incoming2 = valve.wrap(listener2.incoming()); + let incoming1 = valve.wrap(listener1); + let incoming2 = valve.wrap(listener2); let mut incoming = select(incoming1, incoming2); while let Some(mut s) = incoming.next().await.transpose().unwrap() { reqs.fetch_add(1, Ordering::SeqCst); From 3be29bf05b3ca9482a9c620ed994c170c6641a4d Mon Sep 17 00:00:00 2001 From: Jon Gjengset Date: Tue, 20 Oct 2020 15:36:38 -0700 Subject: [PATCH 2/2] pin-project 1.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 08cfedf..3228d03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ maintenance = { status = "passively-maintained" } [dependencies] futures-core = "0.3.0" futures-util = "0.3.0" -pin-project = "0.4.0" +pin-project = "1.0.0" tokio = { version = "0.3.0", features = ["sync", "io-util"] } [dev-dependencies]