Skip to content

Commit

Permalink
initial tokio 0.3 compatibility
Browse files Browse the repository at this point in the history
src/runtime/stream.rs can be improved maybe

Signed-off-by: rupansh-void <rupanshsekar@hotmail.com>
  • Loading branch information
rupansh committed Nov 12, 2020
1 parent 689ac0a commit ebd0abc
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 82 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ exclude = [

[features]
default = ["tokio-runtime"]
tokio-runtime = ["tokio/dns", "tokio/macros", "tokio/rt-core", "tokio/tcp", "tokio/rt-threaded", "tokio/time", "reqwest", "serde_bytes"]
tokio-runtime = ["tokio/macros", "tokio/rt", "tokio/net", "tokio/rt-multi-thread", "tokio/time", "reqwest", "serde_bytes"]
async-std-runtime = ["async-std", "async-std/attributes"]
sync = ["async-std-runtime"]

Expand Down Expand Up @@ -72,7 +72,7 @@ default-features = false
features = ["json", "rustls-tls"]

[dependencies.rustls]
version = "0.17.0"
version = "0.18.1"
features = ["dangerous_configuration"]

[dependencies.serde]
Expand All @@ -84,11 +84,11 @@ version = "0.11.5"
optional = true

[dependencies.tokio]
version = "0.2.18"
version = "0.3.3"
features = ["io-util", "sync", "macros"]

[dependencies.tokio-rustls]
version = "0.13.0"
version = "0.20.0"
features = ["dangerous_configuration"]

[dependencies.uuid]
Expand Down
2 changes: 1 addition & 1 deletion src/cmap/test/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use serde::{de::Unexpected, Deserialize, Deserializer};

use crate::{event::cmap::*, options::StreamAddress, RUNTIME};
use tokio::sync::broadcast::{RecvError, SendError};
use tokio::sync::broadcast::error::{RecvError, SendError};

#[derive(Clone, Debug)]
pub struct EventHandler {
Expand Down
4 changes: 2 additions & 2 deletions src/cmap/test/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async fn concurrent_connections() {
.expect("disabling fail point should succeed");
}

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn connection_error_during_establishment() {
Expand Down Expand Up @@ -190,7 +190,7 @@ async fn connection_error_during_establishment() {
.expect("closed event with error reason should have been seen");
}

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn connection_error_during_operation() {
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ pub enum ErrorKind {
/// A timeout occurred before a Tokio task could be completed.
#[cfg(feature = "tokio-runtime")]
#[error(display = "{}", _0)]
TokioTimeoutElapsed(#[error(source)] tokio::time::Elapsed),
TokioTimeoutElapsed(#[error(source)] tokio::time::error::Elapsed),

#[error(display = "{}", _0)]
RustlsConfig(#[error(source)] rustls::TLSError),
Expand Down
40 changes: 3 additions & 37 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,7 @@ impl AsyncRuntime {
{
match self {
#[cfg(feature = "tokio-runtime")]
Self::Tokio => match TokioCallingContext::current() {
TokioCallingContext::Async(handle) => {
Some(AsyncJoinHandle::Tokio(handle.spawn(fut)))
}
TokioCallingContext::Sync => None,
},
Self::Tokio => Some(AsyncJoinHandle::Tokio(tokio::task::spawn(fut))),

#[cfg(feature = "async-std-runtime")]
Self::AsyncStd => Some(AsyncJoinHandle::AsyncStd(async_std::task::spawn(fut))),
Expand Down Expand Up @@ -85,14 +80,7 @@ impl AsyncRuntime {
{
#[cfg(all(feature = "tokio-runtime", not(feature = "async-std-runtime")))]
{
match TokioCallingContext::current() {
TokioCallingContext::Async(_handle) => {
tokio::task::block_in_place(|| futures::executor::block_on(fut))
}
TokioCallingContext::Sync => {
panic!("block_on called from tokio outside of async context")
}
}
tokio::task::block_in_place(|| futures::executor::block_on(fut))
}

#[cfg(feature = "async-std-runtime")]
Expand All @@ -105,7 +93,7 @@ impl AsyncRuntime {
pub(crate) async fn delay_for(self, delay: Duration) {
#[cfg(feature = "tokio-runtime")]
{
tokio::time::delay_for(delay).await
tokio::time::sleep(delay).await
}

#[cfg(feature = "async-std-runtime")]
Expand Down Expand Up @@ -167,25 +155,3 @@ impl AsyncRuntime {
}
}
}

/// Represents the context in which a given runtime method is being called from.
#[cfg(feature = "tokio-runtime")]
enum TokioCallingContext {
/// From a syncronous setting (i.e. not from a runtime thread).
Sync,

/// From an asyncronous setting (i.e. from an async block or function being run on a runtime).
/// Includes a handle to the current runtime.
Async(tokio::runtime::Handle),
}

#[cfg(feature = "tokio-runtime")]
impl TokioCallingContext {
/// Get the current calling context.
fn current() -> Self {
match tokio::runtime::Handle::try_current() {
Ok(handle) => TokioCallingContext::Async(handle),
Err(_) => TokioCallingContext::Sync,
}
}
}
71 changes: 41 additions & 30 deletions src/runtime/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,10 @@ impl From<async_std::net::TcpStream> for AsyncTcpStream {
}

impl AsyncTcpStream {
#[cfg(feature = "tokio-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use tokio::{net::TcpStream, time::timeout};

let stream_future = TcpStream::connect(address);

let stream = if connect_timeout == Duration::from_secs(0) {
stream_future.await?
} else {
timeout(connect_timeout, stream_future).await??
};

stream.set_keepalive(Some(KEEPALIVE_TIME))?;
stream.set_nodelay(true)?;

Ok(stream.into())
}

#[cfg(feature = "async-std-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use async_std::net::TcpStream;
fn try_connect_common(
address: &SocketAddr,
connect_timeout: Duration,
) -> Result<std::net::TcpStream> {
use socket2::{Domain, Protocol, SockAddr, Socket, Type};

let domain = match address {
Expand All @@ -98,10 +81,24 @@ impl AsyncTcpStream {
} else {
socket.connect_timeout(&address, connect_timeout)?;
}
socket.set_nodelay(true)?;

Ok(socket.into_tcp_stream())
}

#[cfg(feature = "tokio-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use tokio::net::TcpStream;

let stream: TcpStream = socket.into_tcp_stream().into();
stream.set_nodelay(true)?;
let stream = TcpStream::from_std(Self::try_connect_common(address, connect_timeout)?)?;
Ok(stream.into())
}

#[cfg(feature = "async-std-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use async_std::net::TcpStream;

let stream: TcpStream = Self::try_connect_common(address, connect_timeout)?.into();
Ok(stream.into())
}

Expand Down Expand Up @@ -170,7 +167,12 @@ impl AsyncRead for AsyncStream {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(0)),
Self::Tcp(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
Self::Tls(ref mut inner) => Pin::new(inner).poll_read(cx, buf),
Self::Tls(ref mut inner) => {
let mut buf = tokio::io::ReadBuf::new(buf);
Pin::new(inner)
.poll_read(cx, &mut buf)
.map_ok(|_| buf.filled().len())
}
}
}
}
Expand Down Expand Up @@ -214,9 +216,11 @@ impl AsyncRead for AsyncTcpStream {
match self.deref_mut() {
#[cfg(feature = "tokio-runtime")]
Self::Tokio(ref mut stream) => {
use tokio::io::AsyncRead;

Pin::new(stream).poll_read(cx, buf)
// Is there a better way to do this?
let mut buf = tokio::io::ReadBuf::new(buf);
Pin::new(stream)
.poll_read(cx, &mut buf)
.map_ok(|_| buf.filled().len())
}

#[cfg(feature = "async-std-runtime")]
Expand Down Expand Up @@ -279,9 +283,16 @@ impl TokioAsyncRead for AsyncTcpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<tokio::io::Result<usize>> {
AsyncRead::poll_read(self, cx, buf)
buf: &mut tokio::io::ReadBuf,
) -> Poll<tokio::io::Result<()>> {
let s = buf.initialize_unfilled();
let bread = match AsyncRead::poll_read(self, cx, s) {
Poll::Pending => return Poll::Pending,
Poll::Ready(b) => b?,
};

buf.advance(bread);
Poll::Ready(Ok(()))
}
}

Expand Down
15 changes: 9 additions & 6 deletions src/sdam/message_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ impl TopologyMessageManager {
/// Requests that the SDAM background tasks check the topology immediately. This should be
/// called by each server selection operation when it fails to select a server.
pub(super) fn request_topology_check(&self) {
let _ = self.topology_check_requester.broadcast(());
let _ = self.topology_check_requester.send(());
}

/// Notifies the server selection operations that the topology has changed. This should be
/// called by SDAM background tasks after a topology check if the topology has changed.
pub(super) fn notify_topology_changed(&self) {
let _ = self.topology_change_notifier.broadcast(());
let _ = self.topology_change_notifier.send(());
}

pub(super) async fn subscribe_to_topology_check_requests(&self) -> TopologyMessageSubscriber {
Expand All @@ -55,14 +55,17 @@ pub(crate) struct TopologyMessageSubscriber {

impl TopologyMessageSubscriber {
async fn new(receiver: &Receiver<()>) -> Self {
let mut receiver = receiver.clone();
receiver.recv().await;
Self { receiver }
Self {
receiver: receiver.clone(),
}
}

/// Waits for either `timeout` to elapse or a message to be received.
/// Returns true if a message was received, false for a timeout.
pub(crate) async fn wait_for_message(&mut self, timeout: Duration) -> bool {
RUNTIME.timeout(timeout, self.receiver.recv()).await.is_ok()
RUNTIME
.timeout(timeout, self.receiver.changed())
.await
.is_ok()
}
}
2 changes: 1 addition & 1 deletion src/test/spec/retryable_reads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use tokio::sync::RwLockWriteGuard;

use crate::test::{run_spec_test, run_v2_test, LOCK};

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn run() {
let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;
Expand Down

0 comments on commit ebd0abc

Please sign in to comment.