diff --git a/crates/wasi/src/stdio.rs b/crates/wasi/src/stdio.rs index 2a4aae18db96..a8865927d95a 100644 --- a/crates/wasi/src/stdio.rs +++ b/crates/wasi/src/stdio.rs @@ -8,11 +8,9 @@ use crate::{ HostInputStream, HostOutputStream, StreamError, StreamResult, Subscribe, WasiImpl, WasiView, }; use bytes::Bytes; -use std::future::Future; use std::io::IsTerminal; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; +use std::sync::Arc; +use tokio::sync::Mutex; use wasmtime::component::Resource; /// A trait used to represent the standard input to a guest program. @@ -61,6 +59,33 @@ impl StdinStream for pipe::ClosedInputStream { } /// An impl of [`StdinStream`] built on top of [`crate::pipe::AsyncReadStream`]. +// +// Note the usage of `tokio::sync::Mutex` here as opposed to a +// `std::sync::Mutex`. This is intentionally done to implement the `Subscribe` +// variant of this trait. Note that in doing so we're left with the quandry of +// how to implement methods of `HostInputStream` since those methods are not +// `async`. They're currently implemented with `try_lock`, which then raises the +// question of what to do on contention. Currently traps are returned. +// +// Why should it be ok to return a trap? In general concurrency/contention +// shouldn't return a trap since it should be able to happen normally. The +// current assumption, though, is that WASI stdin/stdout streams are special +// enough that the contention case should never come up in practice. Currently +// in WASI there is no actually concurrency, there's just the items in a single +// `Store` and that store owns all of its I/O in a single Tokio task. There's no +// means to actually spawn multiple Tokio tasks that use the same store. This +// means at the very least that there's zero parallelism. Due to the lack of +// multiple tasks that also means that there's no concurrency either. +// +// This `AsyncStdinStream` wrapper is only intended to be used by the WASI +// bindings themselves. It's possible for the host to take this and work with it +// on its own task, but that's niche enough it's not designed for. +// +// Overall that means that the guest is either calling `Subscribe` or it's +// calling `HostInputStream` methods. This means that there should never be +// contention between the two at this time. This may all change in the future +// with WASI 0.3, but perhaps we'll have a better story for stdio at that time +// (see the doc block on the `HostOutputStream` impl below) pub struct AsyncStdinStream(Arc>); impl AsyncStdinStream { @@ -79,30 +104,24 @@ impl StdinStream for AsyncStdinStream { } impl HostInputStream for AsyncStdinStream { - fn read(&mut self, size: usize) -> Result { - self.0.lock().unwrap().read(size) + fn read(&mut self, size: usize) -> Result { + match self.0.try_lock() { + Ok(mut stream) => stream.read(size), + Err(_) => Err(StreamError::trap("concurrent reads are not supported")), + } } - fn skip(&mut self, size: usize) -> Result { - self.0.lock().unwrap().skip(size) + fn skip(&mut self, size: usize) -> Result { + match self.0.try_lock() { + Ok(mut stream) => stream.skip(size), + Err(_) => Err(StreamError::trap("concurrent skips are not supported")), + } } } +#[async_trait::async_trait] impl Subscribe for AsyncStdinStream { - fn ready<'a, 'b>(&'a mut self) -> Pin + Send + 'b>> - where - Self: 'b, - 'a: 'b, - { - struct F(AsyncStdinStream); - impl Future for F { - type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let mut inner = self.0 .0.lock().unwrap(); - let mut fut = inner.ready(); - fut.as_mut().poll(cx) - } - } - Box::pin(F(Self(self.0.clone()))) + async fn ready(&mut self) { + self.0.lock().await.ready().await } } @@ -300,6 +319,10 @@ impl Subscribe for OutputStream { /// A wrapper of [`crate::pipe::AsyncWriteStream`] that implements /// [`StdoutStream`]. Note that the [`HostOutputStream`] impl for this is not /// correct when used for interleaved async IO. +// +// Note that the use of `tokio::sync::Mutex` here is intentional, in addition to +// the `try_lock()` calls below in the implementation of `HostOutputStream`. For +// more information see the documentation on `AsyncStdinStream`. pub struct AsyncStdoutStream(Arc>); impl AsyncStdoutStream { @@ -334,32 +357,29 @@ impl StdoutStream for AsyncStdoutStream { // this comment to correct it: sorry about that. impl HostOutputStream for AsyncStdoutStream { fn check_write(&mut self) -> Result { - self.0.lock().unwrap().check_write() + match self.0.try_lock() { + Ok(mut stream) => stream.check_write(), + Err(_) => Err(StreamError::trap("concurrent writes are not supported")), + } } fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> { - self.0.lock().unwrap().write(bytes) + match self.0.try_lock() { + Ok(mut stream) => stream.write(bytes), + Err(_) => Err(StreamError::trap("concurrent writes not supported yet")), + } } fn flush(&mut self) -> Result<(), StreamError> { - self.0.lock().unwrap().flush() + match self.0.try_lock() { + Ok(mut stream) => stream.flush(), + Err(_) => Err(StreamError::trap("concurrent flushes not supported yet")), + } } } +#[async_trait::async_trait] impl Subscribe for AsyncStdoutStream { - fn ready<'a, 'b>(&'a mut self) -> Pin + Send + 'b>> - where - Self: 'b, - 'a: 'b, - { - struct F(AsyncStdoutStream); - impl Future for F { - type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let mut inner = self.0 .0.lock().unwrap(); - let mut fut = inner.ready(); - fut.as_mut().poll(cx) - } - } - Box::pin(F(Self(self.0.clone()))) + async fn ready(&mut self) { + self.0.lock().await.ready().await } } @@ -464,6 +484,13 @@ where #[cfg(test)] mod test { + use crate::stdio::StdoutStream; + use crate::write_stream::AsyncWriteStream; + use crate::{AsyncStdoutStream, HostOutputStream}; + use anyhow::Result; + use bytes::Bytes; + use tokio::io::AsyncReadExt; + #[test] fn memory_stdin_stream() { // A StdinStream has the property that there are multiple @@ -492,6 +519,7 @@ mod test { let read4 = view2.read(10).expect("read fourth 10 bytes"); assert_eq!(read4, "r the thre".as_bytes(), "fourth 10 bytes"); } + #[tokio::test] async fn async_stdin_stream() { // A StdinStream has the property that there are multiple @@ -530,4 +558,39 @@ mod test { let read4 = view2.read(10).expect("read fourth 10 bytes"); assert_eq!(read4, "r the thre".as_bytes(), "fourth 10 bytes"); } + + #[tokio::test] + async fn async_stdout_stream_unblocks() { + let (mut read, write) = tokio::io::duplex(32); + let stdout = AsyncStdoutStream::new(AsyncWriteStream::new(32, write)); + + let task = tokio::task::spawn(async move { + let mut stream = stdout.stream(); + blocking_write_and_flush(&mut *stream, "x".into()) + .await + .unwrap(); + }); + + let mut buf = [0; 100]; + let n = read.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"x"); + + task.await.unwrap(); + } + + async fn blocking_write_and_flush( + s: &mut dyn HostOutputStream, + mut bytes: Bytes, + ) -> Result<()> { + while !bytes.is_empty() { + let permit = s.write_ready().await?; + let len = bytes.len().min(permit); + let chunk = bytes.split_to(len); + s.write(chunk)?; + } + + s.flush()?; + s.write_ready().await?; + Ok(()) + } }