Skip to content

Commit

Permalink
Merge pull request #8 from jonhoo/tokio03
Browse files Browse the repository at this point in the history
Upgrade to tokio 0.3
  • Loading branch information
jonhoo authored Oct 20, 2020
2 parents 4bebd89 + 3be29bf commit 3e86113
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 47 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ maintenance = { status = "passively-maintained" }
[dependencies]
futures-core = "0.3.0"
futures-util = "0.3.0"
pin-project = "0.4.0"
tokio = { version = "0.2.0", features = ["sync", "io-util"] }
pin-project = "1.0.0"
tokio = { version = "0.3.0", features = ["sync", "io-util"] }

[dev-dependencies]
futures = "0.3.0"
tokio = { version = "0.2.0", features = ["full"] }
tokio = { version = "0.3.0", features = ["full"] }
70 changes: 50 additions & 20 deletions src/combinator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::Trigger;
use futures_core::{ready, stream::Stream};
use pin_project::pin_project;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -43,7 +44,7 @@ pub trait StreamExt: Stream {
/// let (tx, rx) = tokio::sync::oneshot::channel();
///
/// tokio::spawn(async move {
/// let mut incoming = listener.incoming().take_until_if(rx.map(|_| true));
/// let mut incoming = listener.take_until_if(rx.map(|_| true));
/// while let Some(mut s) = incoming.next().await.transpose().unwrap() {
/// tokio::spawn(async move {
/// let (mut r, mut w) = s.split();
Expand Down Expand Up @@ -108,38 +109,67 @@ where
/// `Tripwire` is internally implemented using a `Shared<oneshot::Receiver<()>>`, with the
/// `Trigger` holding the associated `oneshot::Sender`. There is very little magic.
#[pin_project]
#[derive(Clone, Debug)]
pub struct Tripwire(#[pin] watch::Receiver<bool>);
pub struct Tripwire {
watch: watch::Receiver<bool>,

// TODO: existential type
#[pin]
fut: Option<Pin<Box<dyn Future<Output = bool> + Send>>>,
}

impl Clone for Tripwire {
fn clone(&self) -> Self {
Self {
watch: self.watch.clone(),
fut: None,
}
}
}

impl fmt::Debug for Tripwire {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Tripwise").field(&self.watch).finish()
}
}

impl Tripwire {
/// Make a new `Tripwire` and an associated [`Trigger`].
pub fn new() -> (Trigger, Self) {
let (tx, rx) = watch::channel(false);
(Trigger(Some(tx)), Tripwire(rx))
(
Trigger(Some(tx)),
Tripwire {
watch: rx,
fut: None,
},
)
}
}

impl Future for Tripwire {
type Output = bool;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project().0;
loop {
match this.as_mut().poll_recv_ref(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
// channel was closed -- we return whatever the latest value was
}
Poll::Ready(Some(v)) if *v => {
// value change to true, and we should exit
return Poll::Ready(true);
let mut this = self.project();
if this.fut.is_none() {
let mut watch = this.watch.clone();
this.fut.set(Some(Box::pin(async move {
while !*watch.borrow() {
// value is currently false; wait for it to change
if let Err(_) = watch.changed().await {
// channel was closed -- we return whatever the latest value was
return *watch.borrow();
}
}
Poll::Ready(Some(_)) => {
// value is currently false, we need to poll again
continue;
}
}
return Poll::Ready(*this.borrow());
// value change to true, and we should exit
true
})));
}

// Safety: we never move the value inside the option.
// If the Tripwire is pinned, the Option is pinned, and the future inside is as well.
unsafe { this.fut.map_unchecked_mut(|f| f.as_mut().unwrap()) }
.as_mut()
.poll(cx)
}
}

Expand Down
48 changes: 24 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
//!
//! #[tokio::main]
//! async fn main() {
//! let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//! let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//! let (trigger, tripwire) = Tripwire::new();
//!
//! tokio::spawn(async move {
//! let mut incoming = listener.incoming().take_until_if(tripwire);
//! let mut incoming = listener.take_until_if(tripwire);
//! while let Some(mut s) = incoming.next().await.transpose().unwrap() {
//! tokio::spawn(async move {
//! let (mut r, mut w) = s.split();
Expand Down Expand Up @@ -54,10 +54,10 @@
//! #[tokio::main]
//! async fn main() {
//! let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();
//! let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//! let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//!
//! tokio::spawn(async move {
//! let (exit, mut incoming) = Valved::new(listener.incoming());
//! let (exit, mut incoming) = Valved::new(listener);
//! exit_tx.send(exit).unwrap();
//! while let Some(mut s) = incoming.next().await.transpose().unwrap() {
//! tokio::spawn(async move {
Expand Down Expand Up @@ -87,12 +87,12 @@
//! #[tokio::main]
//! async fn main() {
//! let (exit, valve) = Valve::new();
//! let mut listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//! let mut listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//! let listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//! let listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
//!
//! tokio::spawn(async move {
//! let incoming1 = valve.wrap(listener1.incoming());
//! let incoming2 = valve.wrap(listener2.incoming());
//! let incoming1 = valve.wrap(listener1);
//! let incoming2 = valve.wrap(listener2);
//!
//! use futures_util::stream::select;
//! let mut incoming = select(incoming1, incoming2);
Expand Down Expand Up @@ -147,7 +147,7 @@ impl Drop for Trigger {
if let Some(tx) = self.0.take() {
// Send may fail when all associated rx'es are dropped already
// so code here cannot panic on error
let _ = tx.broadcast(true);
let _ = tx.send(true);
}
}
}
Expand All @@ -163,8 +163,8 @@ mod tests {
fn tokio_run() {
use std::thread;

let mut rt = tokio::runtime::Runtime::new().unwrap();
let mut listener = rt
let rt = tokio::runtime::Runtime::new().unwrap();
let listener = rt
.block_on(tokio::net::TcpListener::bind("0.0.0.0:0"))
.unwrap();
let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();
Expand All @@ -173,7 +173,7 @@ mod tests {

// start a tokio echo server
rt.block_on(async move {
let (exit, mut incoming) = Valved::new(listener.incoming());
let (exit, mut incoming) = Valved::new(listener);
exit_tx.send(exit).unwrap();
while let Some(mut s) = incoming.next().await.transpose().unwrap() {
tokio::spawn(async move {
Expand All @@ -200,8 +200,8 @@ mod tests {
let (exit_tx, exit_rx) = tokio::sync::oneshot::channel();

tokio::spawn(async move {
let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let (exit, mut incoming) = Valved::new(listener.incoming());
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let (exit, mut incoming) = Valved::new(listener);
exit_tx.send(exit).unwrap();
while let Some(mut s) = incoming.next().await.transpose().unwrap() {
tokio::spawn(async move {
Expand All @@ -219,10 +219,10 @@ mod tests {
async fn multi_interrupt() {
let (exit, valve) = Valve::new();
tokio::spawn(async move {
let mut listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let mut listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let incoming1 = valve.wrap(listener1.incoming());
let incoming2 = valve.wrap(listener2.incoming());
let listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let incoming1 = valve.wrap(listener1);
let incoming2 = valve.wrap(listener2);

let mut incoming = select(incoming1, incoming2);
while let Some(mut s) = incoming.next().await.transpose().unwrap() {
Expand All @@ -246,13 +246,13 @@ mod tests {
};

let (exit, valve) = Valve::new();
let mut listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let reqs = Arc::new(AtomicUsize::new(0));
let got = reqs.clone();
tokio::spawn(async move {
let mut incoming = valve.wrap(listener.incoming());
let mut incoming = valve.wrap(listener);
while let Some(mut s) = incoming.next().await.transpose().unwrap() {
reqs.fetch_add(1, Ordering::SeqCst);
tokio::spawn(async move {
Expand Down Expand Up @@ -289,17 +289,17 @@ mod tests {
};

let (exit, valve) = Valve::new();
let mut listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let mut listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let listener1 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let listener2 = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
let addr1 = listener1.local_addr().unwrap();
let addr2 = listener2.local_addr().unwrap();

let reqs = Arc::new(AtomicUsize::new(0));
let got = reqs.clone();

tokio::spawn(async move {
let incoming1 = valve.wrap(listener1.incoming());
let incoming2 = valve.wrap(listener2.incoming());
let incoming1 = valve.wrap(listener1);
let incoming2 = valve.wrap(listener2);
let mut incoming = select(incoming1, incoming2);
while let Some(mut s) = incoming.next().await.transpose().unwrap() {
reqs.fetch_add(1, Ordering::SeqCst);
Expand Down

0 comments on commit 3e86113

Please sign in to comment.