Skip to content

Commit

Permalink
Use tokio::sync::Mutex instead of std::sync::Mutex so I can revert to…
Browse files Browse the repository at this point in the history
… regular async code
  • Loading branch information
badeend committed Sep 11, 2024
1 parent 96b340b commit b894432
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 48 deletions.
2 changes: 1 addition & 1 deletion crates/wasi/src/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ where
ShutdownType::Send => std::net::Shutdown::Write,
ShutdownType::Both => std::net::Shutdown::Both,
};
Ok(socket.shutdown(how)?)
socket.shutdown(how)
}

fn drop(&mut self, this: Resource<tcp::TcpSocket>) -> Result<(), anyhow::Error> {
Expand Down
18 changes: 5 additions & 13 deletions crates/wasi/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,13 @@ impl<T> AbortOnDropJoinHandle<T> {
/// Abort the task and wait for it to finish. Optionally returns the result
/// of the task if it ran to completion prior to being aborted.
pub(crate) async fn cancel(mut self) -> Option<T> {
std::future::poll_fn(move |cx| self.poll_cancel(cx)).await
}

/// Abort the task and wait for it to finish. Optionally returns the result
/// of the task if it ran to completion prior to being aborted.
pub(crate) fn poll_cancel(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
self.0.abort();

Poll::Ready(
match std::task::ready!(std::pin::pin!(&mut self.0).poll(cx)) {
Ok(value) => Some(value),
Err(err) if err.is_cancelled() => None,
Err(err) => std::panic::resume_unwind(err.into_panic()),
},
)
match (&mut self.0).await {
Ok(value) => Some(value),
Err(err) if err.is_cancelled() => None,
Err(err) => std::panic::resume_unwind(err.into_panic()),
}
}
}
impl<T> Drop for AbortOnDropJoinHandle<T> {
Expand Down
78 changes: 44 additions & 34 deletions crates/wasi/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::host::network;
use crate::network::SocketAddressFamily;
use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
use crate::{
HostInputStream, HostOutputStream, InputStream, OutputStream, SocketResult, StreamError,
Subscribe,
HostInputStream, HostOutputStream, InputStream, OutputStream, SocketError, SocketResult,
StreamError, Subscribe,
};
use anyhow::Result;
use cap_net_ext::AddressFamily;
Expand All @@ -13,13 +13,13 @@ use io_lifetimes::views::SocketlikeView;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::future::poll_fn;
use std::io;
use std::mem;
use std::net::{Shutdown, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::{Arc, Mutex};
use std::task::{ready, Context, Poll};
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use tokio::sync::Mutex;

/// Value taken from rust std library.
const DEFAULT_BACKLOG: u32 = 128;
Expand Down Expand Up @@ -638,20 +638,17 @@ impl TcpSocket {
Ok(())
}

pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
pub fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
let TcpState::Connected { reader, writer, .. } = &self.tcp_state else {
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"socket not connected",
));
return Err(ErrorCode::InvalidState.into());
};

if let Shutdown::Both | Shutdown::Read = how {
reader.lock().unwrap().shutdown();
try_lock_for_socket(reader)?.shutdown();
}

if let Shutdown::Both | Shutdown::Write = how {
writer.lock().unwrap().shutdown();
try_lock_for_socket(writer)?.shutdown();
}

Ok(())
Expand Down Expand Up @@ -739,11 +736,12 @@ impl TcpReader {
self.closed = true;
}

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
async fn ready(&mut self) {
if self.closed {
return Poll::Ready(());
return;
}
self.stream.poll_read_ready(cx).map(|_| ())

self.stream.readable().await.unwrap();
}
}

Expand All @@ -752,14 +750,14 @@ struct TcpReadStream(Arc<Mutex<TcpReader>>);
#[async_trait::async_trait]
impl HostInputStream for TcpReadStream {
fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
self.0.lock().unwrap().read(size)
try_lock_for_stream(&self.0)?.read(size)
}
}

#[async_trait::async_trait]
impl Subscribe for TcpReadStream {
async fn ready(&mut self) {
poll_fn(move |cx| self.0.lock().unwrap().poll_ready(cx)).await;
self.0.lock().await.ready().await
}
}

Expand Down Expand Up @@ -923,25 +921,23 @@ impl TcpWriter {
};
}

fn poll_cancel(&mut self, cx: &mut Context<'_>) -> Poll<()> {
match &mut self.state {
WriteState::Writing(task) | WriteState::Closing(task) => {
task.poll_cancel(cx).map(|_| ())
}
_ => Poll::Ready(()),
async fn cancel(&mut self) {
match mem::replace(&mut self.state, WriteState::Closed) {
WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
_ => {}
}
}

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
async fn ready(&mut self) {
match &mut self.state {
WriteState::Writing(task) => {
self.state = match ready!(pin!(task).poll(cx)) {
self.state = match task.await {
Ok(()) => WriteState::Ready,
Err(e) => WriteState::Error(e),
}
}
WriteState::Closing(task) => {
self.state = match ready!(pin!(task).poll(cx)) {
self.state = match task.await {
Ok(()) => WriteState::Closed,
Err(e) => WriteState::Error(e),
}
Expand All @@ -950,9 +946,8 @@ impl TcpWriter {
}

if let WriteState::Ready = self.state {
ready!(self.stream.poll_write_ready(cx)).unwrap();
self.stream.writable().await.unwrap();
}
Poll::Ready(())
}
}

Expand All @@ -961,25 +956,26 @@ struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
#[async_trait::async_trait]
impl HostOutputStream for TcpWriteStream {
fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
self.0.lock().unwrap().write(bytes)
try_lock_for_stream(&self.0)?.write(bytes)
}

fn flush(&mut self) -> Result<(), StreamError> {
self.0.lock().unwrap().flush()
try_lock_for_stream(&self.0)?.flush()
}

fn check_write(&mut self) -> Result<usize, StreamError> {
self.0.lock().unwrap().check_write()
try_lock_for_stream(&self.0)?.check_write()
}

async fn cancel(&mut self) {
poll_fn(move |cx| self.0.lock().unwrap().poll_cancel(cx)).await;
self.0.lock().await.cancel().await
}
}

#[async_trait::async_trait]
impl Subscribe for TcpWriteStream {
async fn ready(&mut self) {
poll_fn(move |cx| self.0.lock().unwrap().poll_ready(cx)).await;
self.0.lock().await.ready().await
}
}

Expand All @@ -988,3 +984,17 @@ fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
.as_socketlike_view::<std::net::TcpStream>()
.shutdown(how);
}

fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
mutex
.try_lock()
.map_err(|_| StreamError::trap("concurrent access to resource not supported"))
}

fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, SocketError> {
mutex.try_lock().map_err(|_| {
SocketError::trap(anyhow::anyhow!(
"concurrent access to resource not supported"
))
})
}

0 comments on commit b894432

Please sign in to comment.