Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tokio 1.0 compatibility #267

Closed
wants to merge 10 commits into from
Closed
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ exclude = [

[features]
default = ["tokio-runtime"]
tokio-runtime = ["tokio/dns", "tokio/macros", "tokio/rt-core", "tokio/tcp", "tokio/rt-threaded", "tokio/time", "reqwest", "serde_bytes"]
tokio-runtime = ["tokio/macros", "tokio/rt", "tokio/net", "tokio/rt-multi-thread", "tokio/time", "reqwest", "serde_bytes"]
async-std-runtime = ["async-std", "async-std/attributes"]
sync = ["async-std-runtime"]

Expand Down Expand Up @@ -65,13 +65,13 @@ version = "0.3.0"
default-features = false

[dependencies.reqwest]
version = "0.10.6"
version = "0.11"
optional = true
default-features = false
features = ["json", "rustls-tls"]

[dependencies.rustls]
version = "0.17.0"
version = "0.19.0"
features = ["dangerous_configuration"]

[dependencies.serde]
Expand All @@ -83,11 +83,11 @@ version = "0.11.5"
optional = true

[dependencies.tokio]
version = "~0.2.18"
version = "1.0.1"
features = ["io-util", "sync", "macros"]

[dependencies.tokio-rustls]
version = "0.13.0"
version = "0.22.0"
features = ["dangerous_configuration"]

[dependencies.uuid]
Expand Down
15 changes: 5 additions & 10 deletions src/cmap/status.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use crate::RUNTIME;

/// Struct used to track the latest status of the pool.
#[derive(Clone, Debug)]
struct PoolStatus {
Expand All @@ -15,11 +13,7 @@ impl Default for PoolStatus {

/// Create a channel for publishing and receiving updates to the pool's generation.
pub(super) fn channel() -> (PoolGenerationPublisher, PoolGenerationSubscriber) {
let (sender, mut receiver) = tokio::sync::watch::channel(Default::default());
// The first call to recv on a watch channel returns immediately with the initial value.
// We use RUNTIME.block_in_place because this is not a truly blocking task, so
// the runtimes don't need to shift things around to ensure scheduling continues normally.
RUNTIME.block_in_place(receiver.recv());
let (sender, receiver) = tokio::sync::watch::channel(Default::default());
(
PoolGenerationPublisher { sender },
PoolGenerationSubscriber { receiver },
Expand All @@ -40,7 +34,7 @@ impl PoolGenerationPublisher {
};

// if nobody is listening, this will return an error, which we don't mind.
let _: std::result::Result<_, _> = self.sender.broadcast(new_status);
let _: std::result::Result<_, _> = self.sender.send(new_status);
}
}

Expand All @@ -62,10 +56,11 @@ impl PoolGenerationSubscriber {
timeout: std::time::Duration,
) -> Option<u32> {
crate::RUNTIME
.timeout(timeout, self.receiver.recv())
.timeout(timeout, self.receiver.changed())
.await
.ok()
.map(|r| r.ok())
.flatten()
.map(|status| status.generation)
.map(|_| self.generation())
}
}
2 changes: 1 addition & 1 deletion src/cmap/test/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use serde::{de::Unexpected, Deserialize, Deserializer};

use crate::{event::cmap::*, options::StreamAddress, RUNTIME};
use tokio::sync::broadcast::{RecvError, SendError};
use tokio::sync::broadcast::error::{RecvError, SendError};

#[derive(Clone, Debug)]
pub struct EventHandler {
Expand Down
4 changes: 2 additions & 2 deletions src/cmap/test/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ async fn concurrent_connections() {
.expect("disabling fail point should succeed");
}

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn connection_error_during_establishment() {
Expand Down Expand Up @@ -209,7 +209,7 @@ async fn connection_error_during_establishment() {
.expect("closed event with error reason should have been seen");
}

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn connection_error_during_operation() {
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ pub enum ErrorKind {
/// A timeout occurred before a Tokio task could be completed.
#[cfg(feature = "tokio-runtime")]
#[error(display = "{}", _0)]
TokioTimeoutElapsed(#[error(source)] tokio::time::Elapsed),
TokioTimeoutElapsed(#[error(source)] tokio::time::error::Elapsed),

#[error(display = "{}", _0)]
RustlsConfig(#[error(source)] rustls::TLSError),
Expand Down
40 changes: 3 additions & 37 deletions src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ impl AsyncRuntime {
{
match self {
#[cfg(feature = "tokio-runtime")]
Self::Tokio => match TokioCallingContext::current() {
TokioCallingContext::Async(handle) => {
Some(AsyncJoinHandle::Tokio(handle.spawn(fut)))
}
TokioCallingContext::Sync => None,
},
Self::Tokio => Some(AsyncJoinHandle::Tokio(tokio::task::spawn(fut))),

#[cfg(feature = "async-std-runtime")]
Self::AsyncStd => Some(AsyncJoinHandle::AsyncStd(async_std::task::spawn(fut))),
Expand Down Expand Up @@ -87,14 +82,7 @@ impl AsyncRuntime {
{
#[cfg(all(feature = "tokio-runtime", not(feature = "async-std-runtime")))]
{
match TokioCallingContext::current() {
TokioCallingContext::Async(_handle) => {
tokio::task::block_in_place(|| futures::executor::block_on(fut))
}
TokioCallingContext::Sync => {
panic!("block_on called from tokio outside of async context")
}
}
tokio::task::block_in_place(|| futures::executor::block_on(fut))
}

#[cfg(feature = "async-std-runtime")]
Expand All @@ -118,7 +106,7 @@ impl AsyncRuntime {
pub(crate) async fn delay_for(self, delay: Duration) {
#[cfg(feature = "tokio-runtime")]
{
tokio::time::delay_for(delay).await
tokio::time::sleep(delay).await
}

#[cfg(feature = "async-std-runtime")]
Expand Down Expand Up @@ -180,25 +168,3 @@ impl AsyncRuntime {
}
}
}

/// Represents the context in which a given runtime method is being called from.
#[cfg(feature = "tokio-runtime")]
enum TokioCallingContext {
/// From a syncronous setting (i.e. not from a runtime thread).
Sync,

/// From an asyncronous setting (i.e. from an async block or function being run on a runtime).
/// Includes a handle to the current runtime.
Async(tokio::runtime::Handle),
}

#[cfg(feature = "tokio-runtime")]
impl TokioCallingContext {
/// Get the current calling context.
fn current() -> Self {
match tokio::runtime::Handle::try_current() {
Ok(handle) => TokioCallingContext::Async(handle),
Err(_) => TokioCallingContext::Sync,
}
}
}
83 changes: 52 additions & 31 deletions src/runtime/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,10 @@ impl From<async_std::net::TcpStream> for AsyncTcpStream {
}

impl AsyncTcpStream {
#[cfg(feature = "tokio-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use tokio::{net::TcpStream, time::timeout};

let stream_future = TcpStream::connect(address);

let stream = if connect_timeout == Duration::from_secs(0) {
stream_future.await?
} else {
timeout(connect_timeout, stream_future).await??
};

stream.set_keepalive(Some(KEEPALIVE_TIME))?;
stream.set_nodelay(true)?;

Ok(stream.into())
}

#[cfg(feature = "async-std-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use async_std::net::TcpStream;
fn try_connect_common(
address: &SocketAddr,
connect_timeout: Duration,
) -> Result<std::net::TcpStream> {
use socket2::{Domain, Protocol, SockAddr, Socket, Type};

let domain = match address {
Expand All @@ -98,10 +81,27 @@ impl AsyncTcpStream {
} else {
socket.connect_timeout(&address, connect_timeout)?;
}
socket.set_nonblocking(true)?;
socket.set_nodelay(true)?;

let stream: TcpStream = socket.into_tcp_stream().into();
stream.set_nodelay(true)?;
Ok(socket.into_tcp_stream())
}

#[cfg(feature = "tokio-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use tokio::net::TcpStream;

let addr_c = address.clone();
let connect_task = tokio::task::spawn_blocking(move || Self::try_connect_common(&addr_c, connect_timeout));
let stream = TcpStream::from_std(connect_task.await.map_err(|e| std::io::Error::from(e))??)?;
Ok(stream.into())
}

#[cfg(feature = "async-std-runtime")]
async fn try_connect(address: &SocketAddr, connect_timeout: Duration) -> Result<Self> {
use async_std::net::TcpStream;

let stream: TcpStream = Self::try_connect_common(address, connect_timeout)?.into();
Ok(stream.into())
}

Expand Down Expand Up @@ -170,7 +170,12 @@ impl AsyncRead for AsyncStream {
match self.deref_mut() {
Self::Null => Poll::Ready(Ok(0)),
Self::Tcp(ref mut inner) => AsyncRead::poll_read(Pin::new(inner), cx, buf),
Self::Tls(ref mut inner) => Pin::new(inner).poll_read(cx, buf),
Self::Tls(ref mut inner) => {
let mut buf = tokio::io::ReadBuf::new(buf);
Pin::new(inner)
.poll_read(cx, &mut buf)
.map_ok(|_| buf.filled().len())
}
}
}
}
Expand Down Expand Up @@ -214,9 +219,11 @@ impl AsyncRead for AsyncTcpStream {
match self.deref_mut() {
#[cfg(feature = "tokio-runtime")]
Self::Tokio(ref mut stream) => {
use tokio::io::AsyncRead;

Pin::new(stream).poll_read(cx, buf)
// Is there a better way to do this?
let mut buf = tokio::io::ReadBuf::new(buf);
Pin::new(stream)
.poll_read(cx, &mut buf)
.map_ok(|_| buf.filled().len())
}

#[cfg(feature = "async-std-runtime")]
Expand Down Expand Up @@ -277,11 +284,25 @@ impl AsyncWrite for AsyncTcpStream {

impl TokioAsyncRead for AsyncTcpStream {
fn poll_read(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<tokio::io::Result<usize>> {
AsyncRead::poll_read(self, cx, buf)
buf: &mut tokio::io::ReadBuf,
) -> Poll<tokio::io::Result<()>> {
match self.deref_mut() {
#[cfg(feature = "tokio-runtime")]
Self::Tokio(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "async-std-runtime")]
Self::AsyncStd(ref mut stream) => {
let s = buf.initialize_unfilled();
let bread = match Pin::new(stream).poll_read(cx, s) {
Poll::Pending => return Poll::Pending,
Poll::Ready(b) => b?,
};

buf.advance(bread);
Poll::Ready(Ok(()))
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async fn select_in_window() {
run_spec_test(&["server-selection", "in_window"], run_test).await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn load_balancing_test() {
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;
Expand Down
15 changes: 9 additions & 6 deletions src/sdam/message_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ impl TopologyMessageManager {
/// Requests that the SDAM background tasks check the topology immediately. This should be
/// called by each server selection operation when it fails to select a server.
pub(super) fn request_topology_check(&self) {
let _ = self.topology_check_requester.broadcast(());
let _ = self.topology_check_requester.send(());
}

/// Notifies the server selection operations that the topology has changed. This should be
/// called by SDAM background tasks after a topology check if the topology has changed.
pub(super) fn notify_topology_changed(&self) {
let _ = self.topology_change_notifier.broadcast(());
let _ = self.topology_change_notifier.send(());
}

pub(super) async fn subscribe_to_topology_check_requests(&self) -> TopologyMessageSubscriber {
Expand All @@ -55,14 +55,17 @@ pub(crate) struct TopologyMessageSubscriber {

impl TopologyMessageSubscriber {
async fn new(receiver: &Receiver<()>) -> Self {
let mut receiver = receiver.clone();
receiver.recv().await;
Self { receiver }
Self {
receiver: receiver.clone(),
}
}

/// Waits for either `timeout` to elapse or a message to be received.
/// Returns true if a message was received, false for a timeout.
pub(crate) async fn wait_for_message(&mut self, timeout: Duration) -> bool {
RUNTIME.timeout(timeout, self.receiver.recv()).await.is_ok()
RUNTIME
.timeout(timeout, self.receiver.changed())
.await
.is_ok()
}
}
6 changes: 3 additions & 3 deletions src/sdam/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::{
RUNTIME,
};

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn min_heartbeat_frequency() {
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;
Expand Down Expand Up @@ -83,7 +83,7 @@ async fn min_heartbeat_frequency() {
}

// TODO: RUST-232 update this test to incorporate SDAM events
#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn sdam_pool_management() {
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;
Expand Down Expand Up @@ -151,7 +151,7 @@ async fn sdam_pool_management() {

// prose version of minPoolSize-error.yml SDAM integration test
// TODO: RUST-232 replace this test with the spec runner
#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn sdam_min_pool_size_error() {
let _guard: RwLockWriteGuard<_> = LOCK.run_exclusively().await;
Expand Down
2 changes: 1 addition & 1 deletion src/test/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
RUNTIME,
};

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[function_name::named]
async fn tailable_cursor() {
Expand Down
2 changes: 1 addition & 1 deletion src/test/spec/retryable_reads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use tokio::sync::RwLockWriteGuard;

use crate::test::{run_spec_test, run_v2_test, LOCK};

#[cfg_attr(feature = "tokio-runtime", tokio::test(threaded_scheduler))]
#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn run() {
let _guard: RwLockWriteGuard<()> = LOCK.run_exclusively().await;
Expand Down
Loading