From b8944321bbaf99c8b67f415bf34c7c5ec25351f3 Mon Sep 17 00:00:00 2001 From: badeend Date: Wed, 11 Sep 2024 21:48:08 +0200 Subject: [PATCH] Use tokio::sync::Mutex instead of std::sync::Mutex so I can revert to regular async code --- crates/wasi/src/host/tcp.rs | 2 +- crates/wasi/src/runtime.rs | 18 +++------ crates/wasi/src/tcp.rs | 78 +++++++++++++++++++++---------------- 3 files changed, 50 insertions(+), 48 deletions(-) diff --git a/crates/wasi/src/host/tcp.rs b/crates/wasi/src/host/tcp.rs index 2f20cdc2590e..143ea607eb1c 100644 --- a/crates/wasi/src/host/tcp.rs +++ b/crates/wasi/src/host/tcp.rs @@ -298,7 +298,7 @@ where ShutdownType::Send => std::net::Shutdown::Write, ShutdownType::Both => std::net::Shutdown::Both, }; - Ok(socket.shutdown(how)?) + socket.shutdown(how) } fn drop(&mut self, this: Resource) -> Result<(), anyhow::Error> { diff --git a/crates/wasi/src/runtime.rs b/crates/wasi/src/runtime.rs index be428803a061..b1fb7839618f 100644 --- a/crates/wasi/src/runtime.rs +++ b/crates/wasi/src/runtime.rs @@ -43,21 +43,13 @@ impl AbortOnDropJoinHandle { /// Abort the task and wait for it to finish. Optionally returns the result /// of the task if it ran to completion prior to being aborted. pub(crate) async fn cancel(mut self) -> Option { - std::future::poll_fn(move |cx| self.poll_cancel(cx)).await - } - - /// Abort the task and wait for it to finish. Optionally returns the result - /// of the task if it ran to completion prior to being aborted. - pub(crate) fn poll_cancel(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.abort(); - Poll::Ready( - match std::task::ready!(std::pin::pin!(&mut self.0).poll(cx)) { - Ok(value) => Some(value), - Err(err) if err.is_cancelled() => None, - Err(err) => std::panic::resume_unwind(err.into_panic()), - }, - ) + match (&mut self.0).await { + Ok(value) => Some(value), + Err(err) if err.is_cancelled() => None, + Err(err) => std::panic::resume_unwind(err.into_panic()), + } } } impl Drop for AbortOnDropJoinHandle { diff --git a/crates/wasi/src/tcp.rs b/crates/wasi/src/tcp.rs index 068e68f4dafa..44ba68416c07 100644 --- a/crates/wasi/src/tcp.rs +++ b/crates/wasi/src/tcp.rs @@ -3,8 +3,8 @@ use crate::host::network; use crate::network::SocketAddressFamily; use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle}; use crate::{ - HostInputStream, HostOutputStream, InputStream, OutputStream, SocketResult, StreamError, - Subscribe, + HostInputStream, HostOutputStream, InputStream, OutputStream, SocketError, SocketResult, + StreamError, Subscribe, }; use anyhow::Result; use cap_net_ext::AddressFamily; @@ -13,13 +13,13 @@ use io_lifetimes::views::SocketlikeView; use io_lifetimes::AsSocketlike; use rustix::io::Errno; use rustix::net::sockopt; -use std::future::poll_fn; use std::io; use std::mem; use std::net::{Shutdown, SocketAddr}; -use std::pin::{pin, Pin}; -use std::sync::{Arc, Mutex}; -use std::task::{ready, Context, Poll}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; +use tokio::sync::Mutex; /// Value taken from rust std library. const DEFAULT_BACKLOG: u32 = 128; @@ -638,20 +638,17 @@ impl TcpSocket { Ok(()) } - pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + pub fn shutdown(&self, how: Shutdown) -> SocketResult<()> { let TcpState::Connected { reader, writer, .. } = &self.tcp_state else { - return Err(io::Error::new( - io::ErrorKind::NotConnected, - "socket not connected", - )); + return Err(ErrorCode::InvalidState.into()); }; if let Shutdown::Both | Shutdown::Read = how { - reader.lock().unwrap().shutdown(); + try_lock_for_socket(reader)?.shutdown(); } if let Shutdown::Both | Shutdown::Write = how { - writer.lock().unwrap().shutdown(); + try_lock_for_socket(writer)?.shutdown(); } Ok(()) @@ -739,11 +736,12 @@ impl TcpReader { self.closed = true; } - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + async fn ready(&mut self) { if self.closed { - return Poll::Ready(()); + return; } - self.stream.poll_read_ready(cx).map(|_| ()) + + self.stream.readable().await.unwrap(); } } @@ -752,14 +750,14 @@ struct TcpReadStream(Arc>); #[async_trait::async_trait] impl HostInputStream for TcpReadStream { fn read(&mut self, size: usize) -> Result { - self.0.lock().unwrap().read(size) + try_lock_for_stream(&self.0)?.read(size) } } #[async_trait::async_trait] impl Subscribe for TcpReadStream { async fn ready(&mut self) { - poll_fn(move |cx| self.0.lock().unwrap().poll_ready(cx)).await; + self.0.lock().await.ready().await } } @@ -923,25 +921,23 @@ impl TcpWriter { }; } - fn poll_cancel(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match &mut self.state { - WriteState::Writing(task) | WriteState::Closing(task) => { - task.poll_cancel(cx).map(|_| ()) - } - _ => Poll::Ready(()), + async fn cancel(&mut self) { + match mem::replace(&mut self.state, WriteState::Closed) { + WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await, + _ => {} } } - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + async fn ready(&mut self) { match &mut self.state { WriteState::Writing(task) => { - self.state = match ready!(pin!(task).poll(cx)) { + self.state = match task.await { Ok(()) => WriteState::Ready, Err(e) => WriteState::Error(e), } } WriteState::Closing(task) => { - self.state = match ready!(pin!(task).poll(cx)) { + self.state = match task.await { Ok(()) => WriteState::Closed, Err(e) => WriteState::Error(e), } @@ -950,9 +946,8 @@ impl TcpWriter { } if let WriteState::Ready = self.state { - ready!(self.stream.poll_write_ready(cx)).unwrap(); + self.stream.writable().await.unwrap(); } - Poll::Ready(()) } } @@ -961,25 +956,26 @@ struct TcpWriteStream(Arc>); #[async_trait::async_trait] impl HostOutputStream for TcpWriteStream { fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> { - self.0.lock().unwrap().write(bytes) + try_lock_for_stream(&self.0)?.write(bytes) } fn flush(&mut self) -> Result<(), StreamError> { - self.0.lock().unwrap().flush() + try_lock_for_stream(&self.0)?.flush() } fn check_write(&mut self) -> Result { - self.0.lock().unwrap().check_write() + try_lock_for_stream(&self.0)?.check_write() } + async fn cancel(&mut self) { - poll_fn(move |cx| self.0.lock().unwrap().poll_cancel(cx)).await; + self.0.lock().await.cancel().await } } #[async_trait::async_trait] impl Subscribe for TcpWriteStream { async fn ready(&mut self) { - poll_fn(move |cx| self.0.lock().unwrap().poll_ready(cx)).await; + self.0.lock().await.ready().await } } @@ -988,3 +984,17 @@ fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) { .as_socketlike_view::() .shutdown(how); } + +fn try_lock_for_stream(mutex: &Mutex) -> Result, StreamError> { + mutex + .try_lock() + .map_err(|_| StreamError::trap("concurrent access to resource not supported")) +} + +fn try_lock_for_socket(mutex: &Mutex) -> Result, SocketError> { + mutex.try_lock().map_err(|_| { + SocketError::trap(anyhow::anyhow!( + "concurrent access to resource not supported" + )) + }) +}