diff --git a/crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs b/crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs index 03539ea35de5..7ca127865f6d 100644 --- a/crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs +++ b/crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs @@ -153,23 +153,29 @@ unsafe fn test_fd_readwrite(readable_fd: wasi::Fd, writable_fd: wasi::Fd, error_ ]; let out = poll_oneoff_with_retry(&r#in).unwrap(); assert_eq!(out.len(), 2, "should return 2 events, got: {:?}", out); + + let (read, write) = if out[0].userdata == 1 { + (&out[0], &out[1]) + } else { + (&out[1], &out[0]) + }; assert_eq!( - out[0].userdata, 1, + read.userdata, 1, "the event.userdata should contain fd userdata specified by the user" ); - assert_errno!(out[0].error, error_code); + assert_errno!(read.error, error_code); assert_eq!( - out[0].type_, + read.type_, wasi::EVENTTYPE_FD_READ, "the event.type_ should equal FD_READ" ); assert_eq!( - out[1].userdata, 2, + write.userdata, 2, "the event.userdata should contain fd userdata specified by the user" ); - assert_errno!(out[1].error, error_code); + assert_errno!(write.error, error_code); assert_eq!( - out[1].type_, + write.type_, wasi::EVENTTYPE_FD_WRITE, "the event.type_ should equal FD_WRITE" ); diff --git a/crates/wasi-http/src/body.rs b/crates/wasi-http/src/body.rs index fa021a3ddd61..9d27832049be 100644 --- a/crates/wasi-http/src/body.rs +++ b/crates/wasi-http/src/body.rs @@ -11,7 +11,7 @@ use std::{ use tokio::sync::{mpsc, oneshot}; use wasmtime_wasi::preview2::{ self, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, OutputStreamError, - StreamRuntimeError, StreamState, + StreamRuntimeError, StreamState, Subscribe, }; pub type HyperIncomingBody = BoxBody; @@ -189,14 +189,17 @@ impl HostInputStream for HostIncomingBodyStream { } } } +} - async fn ready(&mut self) -> anyhow::Result<()> { +#[async_trait::async_trait] +impl Subscribe for HostIncomingBodyStream { + async fn ready(&mut self) { if !self.buffer.is_empty() { - return Ok(()); + return; } if !self.open { - return Ok(()); + return; } match self.receiver.recv().await { @@ -209,8 +212,6 @@ impl HostInputStream for HostIncomingBodyStream { None => self.open = false, } - - Ok(()) } } @@ -224,8 +225,9 @@ pub enum HostFutureTrailersState { Done(Result), } -impl HostFutureTrailers { - pub async fn ready(&mut self) -> anyhow::Result<()> { +#[async_trait::async_trait] +impl Subscribe for HostFutureTrailers { + async fn ready(&mut self) { if let HostFutureTrailersState::Waiting(rx) = &mut self.state { let result = match rx.await { Ok(Ok(headers)) => Ok(FieldMap::from(headers)), @@ -236,7 +238,6 @@ impl HostFutureTrailers { }; self.state = HostFutureTrailersState::Done(result); } - Ok(()) } } @@ -353,11 +354,6 @@ enum Job { Write(Bytes), } -enum WriteStatus<'a> { - Done(Result), - Pending(tokio::sync::futures::Notified<'a>), -} - impl Worker { fn new(write_budget: usize) -> Self { Self { @@ -372,17 +368,31 @@ impl Worker { write_ready_changed: tokio::sync::Notify::new(), } } - fn check_write(&self) -> WriteStatus<'_> { + async fn ready(&self) { + loop { + { + let state = self.state(); + if state.error.is_some() + || !state.alive + || (!state.flush_pending && state.write_budget > 0) + { + return; + } + } + self.write_ready_changed.notified().await; + } + } + fn check_write(&self) -> Result { let mut state = self.state(); if let Err(e) = state.check_error() { - return WriteStatus::Done(Err(e)); + return Err(e); } if state.flush_pending || state.write_budget == 0 { - return WriteStatus::Pending(self.write_ready_changed.notified()); + return Ok(0); } - WriteStatus::Done(Ok(state.write_budget)) + Ok(state.write_budget) } fn state(&self) -> std::sync::MutexGuard { self.state.lock().unwrap() @@ -496,12 +506,13 @@ impl HostOutputStream for BodyWriteStream { Ok(()) } - async fn write_ready(&mut self) -> Result { - loop { - match self.worker.check_write() { - WriteStatus::Done(r) => return r, - WriteStatus::Pending(notifier) => notifier.await, - } - } + fn check_write(&mut self) -> Result { + self.worker.check_write() + } +} +#[async_trait::async_trait] +impl Subscribe for BodyWriteStream { + async fn ready(&mut self) { + self.worker.ready().await } } diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 391c5d91851e..2b6ff34b1056 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -12,10 +12,8 @@ use crate::{ }, }; use std::any::Any; -use std::pin::Pin; -use std::task; use wasmtime::component::Resource; -use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Table, TableError}; +use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Subscribe, Table, TableError}; /// Capture the state necessary for use in the wasi-http API implementation. pub struct WasiHttpCtx; @@ -167,21 +165,11 @@ impl HostFutureIncomingResponse { } } -impl std::future::Future for HostFutureIncomingResponse { - type Output = anyhow::Result<()>; - - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll { - let s = self.get_mut(); - match s { - Self::Pending(ref mut handle) => match Pin::new(handle).poll(cx) { - task::Poll::Pending => task::Poll::Pending, - task::Poll::Ready(r) => { - *s = Self::Ready(r); - task::Poll::Ready(Ok(())) - } - }, - - Self::Consumed | Self::Ready(_) => task::Poll::Ready(Ok(())), +#[async_trait::async_trait] +impl Subscribe for HostFutureIncomingResponse { + async fn ready(&mut self) { + if let Self::Pending(handle) = self { + *self = Self::Ready(handle.await); } } } diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index af91caea411d..a470ece0ef3d 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -18,7 +18,7 @@ use std::any::Any; use wasmtime::component::Resource; use wasmtime_wasi::preview2::{ bindings::io::streams::{InputStream, OutputStream}, - Pollable, PollableFuture, + Pollable, }; impl crate::bindings::http::types::Host for T { @@ -352,19 +352,10 @@ impl crate::bindings::http::types::Host for T { &mut self, index: FutureTrailers, ) -> wasmtime::Result> { - // Eagerly force errors about the validity of the index. - let _ = self.table().get_future_trailers(index)?; - - fn make_future(elem: &mut dyn Any) -> PollableFuture { - Box::pin(elem.downcast_mut::().unwrap().ready()) - } - - // FIXME: this should use `push_child_resource` - let id = self - .table() - .push_resource(Pollable::TableEntry { index, make_future })?; - - Ok(id) + wasmtime_wasi::preview2::subscribe( + self.table(), + Resource::::new_borrow(index), + ) } fn future_trailers_get( @@ -480,22 +471,10 @@ impl crate::bindings::http::types::Host for T { &mut self, id: FutureIncomingResponse, ) -> wasmtime::Result> { - let _ = self.table().get_future_incoming_response(id)?; - - fn make_future<'a>(elem: &'a mut dyn Any) -> PollableFuture<'a> { - Box::pin( - elem.downcast_mut::() - .expect("parent resource is HostFutureIncomingResponse"), - ) - } - - // FIXME: this should use `push_child_resource` - let pollable = self.table().push_resource(Pollable::TableEntry { - index: id, - make_future, - })?; - - Ok(pollable) + wasmtime_wasi::preview2::subscribe( + self.table(), + Resource::::new_borrow(id), + ) } fn outgoing_body_write( diff --git a/crates/wasi/src/preview2/filesystem.rs b/crates/wasi/src/preview2/filesystem.rs index c360b874c801..e09d32df1d9c 100644 --- a/crates/wasi/src/preview2/filesystem.rs +++ b/crates/wasi/src/preview2/filesystem.rs @@ -1,10 +1,12 @@ use crate::preview2::bindings::filesystem::types; use crate::preview2::{ AbortOnDropJoinHandle, HostOutputStream, OutputStreamError, StreamRuntimeError, StreamState, + Subscribe, }; use anyhow::anyhow; use bytes::{Bytes, BytesMut}; -use futures::future::{maybe_done, MaybeDone}; +use std::io; +use std::mem; use std::sync::Arc; pub enum Descriptor { @@ -156,11 +158,11 @@ impl FileInputStream { } } -fn read_result(r: Result) -> Result<(usize, StreamState), anyhow::Error> { +fn read_result(r: io::Result) -> Result<(usize, StreamState), anyhow::Error> { match r { Ok(0) => Ok((0, StreamState::Closed)), Ok(n) => Ok((n, StreamState::Open)), - Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok((0, StreamState::Open)), + Err(e) if e.kind() == io::ErrorKind::Interrupted => Ok((0, StreamState::Open)), Err(e) => Err(StreamRuntimeError::from(anyhow!(e)).into()), } } @@ -174,26 +176,32 @@ pub(crate) enum FileOutputMode { pub(crate) struct FileOutputStream { file: Arc, mode: FileOutputMode, - // Allows join future to be awaited in a cancellable manner. Gone variant indicates - // no task is currently outstanding. - task: MaybeDone>>, - closed: bool, + state: OutputState, } + +enum OutputState { + Ready, + /// Allows join future to be awaited in a cancellable manner. Gone variant indicates + /// no task is currently outstanding. + Waiting(AbortOnDropJoinHandle>), + /// The last I/O operation failed with this error. + Error(io::Error), + Closed, +} + impl FileOutputStream { pub fn write_at(file: Arc, position: u64) -> Self { Self { file, mode: FileOutputMode::Position(position), - task: MaybeDone::Gone, - closed: false, + state: OutputState::Ready, } } pub fn append(file: Arc) -> Self { Self { file, mode: FileOutputMode::Append, - task: MaybeDone::Gone, - closed: false, + state: OutputState::Ready, } } } @@ -201,74 +209,79 @@ impl FileOutputStream { // FIXME: configurable? determine from how much space left in file? const FILE_WRITE_CAPACITY: usize = 1024 * 1024; -#[async_trait::async_trait] impl HostOutputStream for FileOutputStream { fn write(&mut self, buf: Bytes) -> Result<(), OutputStreamError> { use system_interface::fs::FileIoExt; - - if self.closed { - return Err(OutputStreamError::Closed); - } - if !matches!(self.task, MaybeDone::Gone) { - // a write is pending - this call was not permitted - return Err(OutputStreamError::Trap(anyhow!( - "write not permitted: FileOutputStream write pending" - ))); + match self.state { + OutputState::Ready => {} + OutputState::Closed => return Err(OutputStreamError::Closed), + OutputState::Waiting(_) | OutputState::Error(_) => { + // a write is pending - this call was not permitted + return Err(OutputStreamError::Trap(anyhow!( + "write not permitted: check_write not called first" + ))); + } } + let f = Arc::clone(&self.file); let m = self.mode; - self.task = maybe_done(AbortOnDropJoinHandle::from(tokio::task::spawn_blocking( - move || match m { - FileOutputMode::Position(mut p) => { - let mut buf = buf; - while !buf.is_empty() { - let nwritten = f.write_at(buf.as_ref(), p)?; - // afterwards buf contains [nwritten, len): - let _ = buf.split_to(nwritten); - p += nwritten as u64; - } - Ok(()) + let task = AbortOnDropJoinHandle::from(tokio::task::spawn_blocking(move || match m { + FileOutputMode::Position(mut p) => { + let mut buf = buf; + while !buf.is_empty() { + let nwritten = f.write_at(buf.as_ref(), p)?; + // afterwards buf contains [nwritten, len): + let _ = buf.split_to(nwritten); + p += nwritten as u64; } - FileOutputMode::Append => { - let mut buf = buf; - while !buf.is_empty() { - let nwritten = f.append(buf.as_ref())?; - let _ = buf.split_to(nwritten); - } - Ok(()) + Ok(()) + } + FileOutputMode::Append => { + let mut buf = buf; + while !buf.is_empty() { + let nwritten = f.append(buf.as_ref())?; + let _ = buf.split_to(nwritten); } - }, - ))); + Ok(()) + } + })); + self.state = OutputState::Waiting(task); Ok(()) } fn flush(&mut self) -> Result<(), OutputStreamError> { - if self.closed { - return Err(OutputStreamError::Closed); + match self.state { + // Only userland buffering of file writes is in the blocking task, + // so there's nothing extra that needs to be done to request a + // flush. + OutputState::Ready | OutputState::Waiting(_) => Ok(()), + OutputState::Closed => Err(OutputStreamError::Closed), + OutputState::Error(_) => match mem::replace(&mut self.state, OutputState::Closed) { + OutputState::Error(e) => Err(OutputStreamError::LastOperationFailed(e.into())), + _ => unreachable!(), + }, } - // Only userland buffering of file writes is in the blocking task. - Ok(()) } - async fn write_ready(&mut self) -> Result { - if self.closed { - return Err(OutputStreamError::Closed); - } - // If there is no outstanding task, accept more input: - if matches!(self.task, MaybeDone::Gone) { - return Ok(FILE_WRITE_CAPACITY); + fn check_write(&mut self) -> Result { + match self.state { + OutputState::Ready => Ok(FILE_WRITE_CAPACITY), + OutputState::Closed => Err(OutputStreamError::Closed), + OutputState::Error(_) => match mem::replace(&mut self.state, OutputState::Closed) { + OutputState::Error(e) => Err(OutputStreamError::LastOperationFailed(e.into())), + _ => unreachable!(), + }, + OutputState::Waiting(_) => Ok(0), } - // Wait for outstanding task: - std::pin::Pin::new(&mut self.task).await; + } +} - // Mark task as finished, and handle output: - match std::pin::Pin::new(&mut self.task) - .take_output() - .expect("just awaited for MaybeDone completion") - { - Ok(()) => Ok(FILE_WRITE_CAPACITY), - Err(e) => { - self.closed = true; - Err(OutputStreamError::LastOperationFailed(e.into())) - } +#[async_trait::async_trait] +impl Subscribe for FileOutputStream { + async fn ready(&mut self) { + if let OutputState::Waiting(task) = &mut self.state { + self.state = match task.await { + Ok(()) => OutputState::Ready, + Err(e) => OutputState::Error(e), + }; } } } diff --git a/crates/wasi/src/preview2/host/clocks.rs b/crates/wasi/src/preview2/host/clocks.rs index 15adedf5770d..105f8b3ee47f 100644 --- a/crates/wasi/src/preview2/host/clocks.rs +++ b/crates/wasi/src/preview2/host/clocks.rs @@ -5,8 +5,10 @@ use crate::preview2::bindings::{ clocks::timezone::{self, TimezoneDisplay}, clocks::wall_clock::{self, Datetime}, }; +use crate::preview2::poll::{subscribe, Subscribe}; use crate::preview2::{Pollable, WasiView}; use cap_std::time::SystemTime; +use std::time::Duration; use wasmtime::component::Resource; impl TryFrom for Datetime { @@ -51,42 +53,29 @@ impl monotonic_clock::Host for T { } fn subscribe(&mut self, when: Instant, absolute: bool) -> anyhow::Result> { - use std::time::Duration; - // Calculate time relative to clock object, which may not have the same zero - // point as tokio Inst::now() let clock_now = self.ctx().monotonic_clock.now(); - if absolute && when < clock_now { - // Deadline is in the past, so pollable is always ready: - Ok(self - .table_mut() - .push_resource(Pollable::Closure(Box::new(|| Box::pin(async { Ok(()) }))))?) + let duration = if absolute { + Duration::from_nanos(when - clock_now) } else { - let duration = if absolute { - Duration::from_nanos(when - clock_now) - } else { - Duration::from_nanos(when) - }; - let deadline = tokio::time::Instant::now() - .checked_add(duration) - .ok_or_else(|| anyhow::anyhow!("time overflow: duration {duration:?}"))?; - tracing::trace!( - "deadline = {:?}, now = {:?}", - deadline, - tokio::time::Instant::now() - ); - Ok(self - .table_mut() - .push_resource(Pollable::Closure(Box::new(move || { - Box::pin(async move { - tracing::trace!( - "mkf: deadline = {:?}, now = {:?}", - deadline, - tokio::time::Instant::now() - ); - Ok(tokio::time::sleep_until(deadline).await) - }) - })))?) - } + Duration::from_nanos(when) + }; + let deadline = tokio::time::Instant::now() + .checked_add(duration) + .ok_or_else(|| anyhow::anyhow!("time overflow: duration {duration:?}"))?; + // NB: this resource created here is not actually exposed to wasm, it's + // only an internal implementation detail used to match the signature + // expected by `subscribe`. + let sleep = self.table_mut().push_resource(Sleep(deadline))?; + subscribe(self.table_mut(), sleep) + } +} + +struct Sleep(tokio::time::Instant); + +#[async_trait::async_trait] +impl Subscribe for Sleep { + async fn ready(&mut self) { + tokio::time::sleep_until(self.0).await; } } diff --git a/crates/wasi/src/preview2/host/io.rs b/crates/wasi/src/preview2/host/io.rs index 504fccac1875..a9e18de4caf3 100644 --- a/crates/wasi/src/preview2/host/io.rs +++ b/crates/wasi/src/preview2/host/io.rs @@ -1,13 +1,9 @@ use crate::preview2::{ bindings::io::streams::{self, InputStream, OutputStream}, - poll::PollableFuture, + poll::subscribe, stream::{OutputStreamError, StreamRuntimeError, StreamState}, Pollable, TableError, WasiView, }; -use std::any::Any; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; use wasmtime::component::Resource; impl From for streams::StreamStatus { @@ -48,14 +44,8 @@ impl streams::HostOutputStream for T { } fn check_write(&mut self, stream: Resource) -> Result { - let s = self.table_mut().get_resource_mut(&stream)?; - let mut ready = s.write_ready(); - let mut task = Context::from_waker(futures::task::noop_waker_ref()); - match Pin::new(&mut ready).poll(&mut task) { - Poll::Ready(Ok(permit)) => Ok(permit as u64), - Poll::Ready(Err(e)) => Err(e.into()), - Poll::Pending => Ok(0), - } + let bytes = self.table_mut().get_resource_mut(&stream)?.check_write()?; + Ok(bytes as u64) } fn write( @@ -70,23 +60,7 @@ impl streams::HostOutputStream for T { } fn subscribe(&mut self, stream: Resource) -> anyhow::Result> { - fn output_stream_ready<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { - let stream = stream - .downcast_mut::() - .expect("downcast to OutputStream failed"); - Box::pin(async move { - let _ = stream.write_ready().await?; - Ok(()) - }) - } - - Ok(self.table_mut().push_child_resource( - Pollable::TableEntry { - index: stream.rep(), - make_future: output_stream_ready, - }, - &stream, - )?) + subscribe(self.table_mut(), stream) } async fn blocking_write_and_flush( @@ -291,7 +265,7 @@ impl streams::HostInputStream for T { len: u64, ) -> anyhow::Result, streams::StreamStatus), ()>> { if let InputStream::Host(s) = self.table_mut().get_resource_mut(&stream)? { - s.ready().await?; + s.ready().await; } self.read(stream, len).await } @@ -341,37 +315,13 @@ impl streams::HostInputStream for T { len: u64, ) -> anyhow::Result> { if let InputStream::Host(s) = self.table_mut().get_resource_mut(&stream)? { - s.ready().await?; + s.ready().await; } self.skip(stream, len).await } fn subscribe(&mut self, stream: Resource) -> anyhow::Result> { - // Ensure that table element is an input-stream: - let pollable = match self.table_mut().get_resource(&stream)? { - InputStream::Host(_) => { - fn input_stream_ready<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { - let stream = stream - .downcast_mut::() - .expect("downcast to InputStream failed"); - match *stream { - InputStream::Host(ref mut hs) => hs.ready(), - _ => unreachable!(), - } - } - - Pollable::TableEntry { - index: stream.rep(), - make_future: input_stream_ready, - } - } - // Files are always "ready" immediately (because we have no way to actually wait on - // readiness in epoll) - InputStream::File(_) => { - Pollable::Closure(Box::new(|| Box::pin(futures::future::ready(Ok(()))))) - } - }; - Ok(self.table_mut().push_child_resource(pollable, &stream)?) + crate::preview2::poll::subscribe(self.table_mut(), stream) } } diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index 185807504fd2..fe4e2c164c2d 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -4,13 +4,12 @@ use crate::preview2::bindings::{ sockets::tcp::{self, ShutdownType}, }; use crate::preview2::tcp::{TcpSocket, TcpState}; -use crate::preview2::{Pollable, PollableFuture, WasiView}; +use crate::preview2::{Pollable, WasiView}; use cap_net_ext::{Blocking, PoolExt, TcpListenerExt}; use cap_std::net::TcpListener; use io_lifetimes::AsSocketlike; use rustix::io::Errno; use rustix::net::sockopt; -use std::any::Any; use tokio::io::Interest; use wasmtime::component::Resource; @@ -440,38 +439,7 @@ impl crate::preview2::host::tcp::tcp::HostTcpSocket for T { } fn subscribe(&mut self, this: Resource) -> anyhow::Result> { - fn make_tcp_socket_future<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { - let socket = stream - .downcast_mut::() - .expect("downcast to TcpSocket failed"); - - // Some states are ready immediately. - match socket.tcp_state { - TcpState::BindStarted | TcpState::ListenStarted | TcpState::ConnectReady => { - return Box::pin(async { Ok(()) }) - } - _ => {} - } - - // FIXME: Add `Interest::ERROR` when we update to tokio 1.32. - let join = Box::pin(async move { - socket - .inner - .ready(Interest::READABLE | Interest::WRITABLE) - .await - .unwrap(); - Ok(()) - }); - - join - } - - let pollable = Pollable::TableEntry { - index: this.rep(), - make_future: make_tcp_socket_future, - }; - - Ok(self.table_mut().push_child_resource(pollable, &this)?) + crate::preview2::poll::subscribe(self.table_mut(), this) } fn shutdown( diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index e3f171a11973..de8dd801dcd2 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -15,6 +15,10 @@ //! `pub mod legacy` with an off-by-default feature flag, and after 2 //! releases, retire and remove that code from our tree. +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + mod clocks; pub mod command; mod ctx; @@ -37,7 +41,7 @@ pub use self::clocks::{HostMonotonicClock, HostWallClock}; pub use self::ctx::{WasiCtx, WasiCtxBuilder, WasiView}; pub use self::error::I32Exit; pub use self::filesystem::{DirPerms, FilePerms}; -pub use self::poll::{ClosureFuture, MakeFuture, Pollable, PollableFuture}; +pub use self::poll::{subscribe, ClosureFuture, MakeFuture, Pollable, PollableFuture, Subscribe}; pub use self::random::{thread_rng, Deterministic}; pub use self::stdio::{stderr, stdin, stdout, IsATTY, Stderr, Stdin, Stdout}; pub use self::stream::{ @@ -193,14 +197,9 @@ impl From> for AbortOnDropJoinHandle { AbortOnDropJoinHandle(jh) } } -impl std::future::Future for AbortOnDropJoinHandle { +impl Future for AbortOnDropJoinHandle { type Output = T; - fn poll( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - use std::pin::Pin; - use std::task::Poll; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match Pin::new(&mut self.as_mut().0).poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(r) => Poll::Ready(r.expect("child task panicked")), @@ -210,7 +209,7 @@ impl std::future::Future for AbortOnDropJoinHandle { pub fn spawn(f: F) -> AbortOnDropJoinHandle where - F: std::future::Future + Send + 'static, + F: Future + Send + 'static, G: Send + 'static, { let j = match tokio::runtime::Handle::try_current() { @@ -223,7 +222,7 @@ where AbortOnDropJoinHandle(j) } -pub fn in_tokio(f: F) -> F::Output { +pub fn in_tokio(f: F) -> F::Output { match tokio::runtime::Handle::try_current() { Ok(h) => { let _enter = h.enter(); @@ -245,3 +244,14 @@ fn with_ambient_tokio_runtime(f: impl FnOnce() -> R) -> R { } } } + +fn poll_noop(future: Pin<&mut F>) -> Option +where + F: Future, +{ + let mut task = Context::from_waker(futures::task::noop_waker_ref()); + match future.poll(&mut task) { + Poll::Ready(result) => Some(result), + Poll::Pending => None, + } +} diff --git a/crates/wasi/src/preview2/pipe.rs b/crates/wasi/src/preview2/pipe.rs index c261760d7d93..b99e68a28524 100644 --- a/crates/wasi/src/preview2/pipe.rs +++ b/crates/wasi/src/preview2/pipe.rs @@ -7,6 +7,7 @@ //! Some convenience constructors are included for common backing types like `Vec` and `String`, //! but the virtual pipes can be instantiated with any `Read` or `Write` type. //! +use crate::preview2::poll::Subscribe; use crate::preview2::{HostInputStream, HostOutputStream, OutputStreamError, StreamState}; use anyhow::{anyhow, Error}; use bytes::Bytes; @@ -49,10 +50,11 @@ impl HostInputStream for MemoryInputPipe { }; Ok((read, state)) } +} - async fn ready(&mut self) -> Result<(), Error> { - Ok(()) - } +#[async_trait::async_trait] +impl Subscribe for MemoryInputPipe { + async fn ready(&mut self) {} } #[derive(Debug, Clone)] @@ -78,7 +80,6 @@ impl MemoryOutputPipe { } } -#[async_trait::async_trait] impl HostOutputStream for MemoryOutputPipe { fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { let mut buf = self.buffer.lock().unwrap(); @@ -95,7 +96,7 @@ impl HostOutputStream for MemoryOutputPipe { // This stream is always flushed Ok(()) } - async fn write_ready(&mut self) -> Result { + fn check_write(&mut self) -> Result { let consumed = self.buffer.lock().unwrap().len(); if consumed < self.capacity { Ok(self.capacity - consumed) @@ -106,6 +107,11 @@ impl HostOutputStream for MemoryOutputPipe { } } +#[async_trait::async_trait] +impl Subscribe for MemoryOutputPipe { + async fn ready(&mut self) {} +} + /// Provides a [`HostInputStream`] impl from a [`tokio::io::AsyncRead`] impl pub struct AsyncReadStream { state: StreamState, @@ -189,10 +195,12 @@ impl HostInputStream for AsyncReadStream { )), } } - - async fn ready(&mut self) -> Result<(), Error> { +} +#[async_trait::async_trait] +impl Subscribe for AsyncReadStream { + async fn ready(&mut self) { if self.buffer.is_some() || self.state == StreamState::Closed { - return Ok(()); + return; } match self.receiver.recv().await { Some(Ok((bytes, state))) => { @@ -203,12 +211,9 @@ impl HostInputStream for AsyncReadStream { } Some(Err(e)) => self.buffer = Some(Err(e)), None => { - return Err(anyhow!( - "no more sender for an open AsyncReadStream - should be impossible" - )) + panic!("no more sender for an open AsyncReadStream - should be impossible") } } - Ok(()) } } @@ -216,7 +221,6 @@ impl HostInputStream for AsyncReadStream { #[derive(Copy, Clone)] pub struct SinkOutputStream; -#[async_trait::async_trait] impl HostOutputStream for SinkOutputStream { fn write(&mut self, _buf: Bytes) -> Result<(), OutputStreamError> { Ok(()) @@ -226,12 +230,17 @@ impl HostOutputStream for SinkOutputStream { Ok(()) } - async fn write_ready(&mut self) -> Result { + fn check_write(&mut self) -> Result { // This stream is always ready for writing. Ok(usize::MAX) } } +#[async_trait::async_trait] +impl Subscribe for SinkOutputStream { + async fn ready(&mut self) {} +} + /// A stream that is ready immediately, but will always report that it's closed. #[derive(Copy, Clone)] pub struct ClosedInputStream; @@ -241,17 +250,17 @@ impl HostInputStream for ClosedInputStream { fn read(&mut self, _size: usize) -> Result<(Bytes, StreamState), Error> { Ok((Bytes::new(), StreamState::Closed)) } +} - async fn ready(&mut self) -> Result<(), Error> { - Ok(()) - } +#[async_trait::async_trait] +impl Subscribe for ClosedInputStream { + async fn ready(&mut self) {} } /// An output stream that is always closed. #[derive(Copy, Clone)] pub struct ClosedOutputStream; -#[async_trait::async_trait] impl HostOutputStream for ClosedOutputStream { fn write(&mut self, _: Bytes) -> Result<(), OutputStreamError> { Err(OutputStreamError::Closed) @@ -260,11 +269,16 @@ impl HostOutputStream for ClosedOutputStream { Err(OutputStreamError::Closed) } - async fn write_ready(&mut self) -> Result { + fn check_write(&mut self) -> Result { Err(OutputStreamError::Closed) } } +#[async_trait::async_trait] +impl Subscribe for ClosedOutputStream { + async fn ready(&mut self) {} +} + #[cfg(test)] mod test { use super::*; @@ -323,9 +337,7 @@ mod test { // The reader task hasn't run yet. Call `ready` to await and fill the buffer. StreamState::Open => { - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; let (bs, state) = reader.read(0).unwrap(); assert!(bs.is_empty()); assert_eq!(state, StreamState::Closed); @@ -341,9 +353,7 @@ mod test { assert_eq!(state, StreamState::Open); if bs.is_empty() { // Reader task hasn't run yet. Call `ready` to await and fill the buffer. - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Now a read should succeed let (bs, state) = reader.read(10).unwrap(); assert_eq!(bs.len(), 10); @@ -377,9 +387,7 @@ mod test { assert_eq!(state, StreamState::Open); if bs.is_empty() { // Reader task hasn't run yet. Call `ready` to await and fill the buffer. - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Now a read should succeed let (bs, state) = reader.read(123).unwrap(); assert_eq!(bs.len(), 123); @@ -396,9 +404,7 @@ mod test { StreamState::Closed => {} // Correct! StreamState::Open => { // Need to await to give this side time to catch up - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Now a read should show closed let (bs, state) = reader.read(0).unwrap(); assert_eq!(bs.len(), 0); @@ -420,9 +426,7 @@ mod test { assert_eq!(state, StreamState::Open); if bs.is_empty() { // Reader task hasn't run yet. Call `ready` to await and fill the buffer. - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Now a read should succeed let (bs, state) = reader.read(1).unwrap(); assert_eq!(*bs, [123u8]); @@ -449,9 +453,7 @@ mod test { // Wait readiness (yes we could possibly win the race and read it out faster, leaving that // out of the test for simplicity) - resolves_immediately(reader.ready()) - .await - .expect("the ready is ok"); + resolves_immediately(reader.ready()).await; // read the something else back out: let (bs, state) = reader.read(1).unwrap(); @@ -476,9 +478,7 @@ mod test { // Wait readiness (yes we could possibly win the race and read it out faster, leaving that // out of the test for simplicity) - resolves_immediately(reader.ready()) - .await - .expect("the ready is ok"); + resolves_immediately(reader.ready()).await; // empty and now closed: let (bs, state) = reader.read(1).unwrap(); @@ -500,9 +500,7 @@ mod test { w }); - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Now we expect the reader task has sent 4k from the stream to the reader. // Try to read out one bigger than the buffer available: @@ -511,9 +509,7 @@ mod test { assert_eq!(state, StreamState::Open); // Allow the crank to turn more: - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Again we expect the reader task has sent 4k from the stream to the reader. // Try to read out one bigger than the buffer available: @@ -528,9 +524,7 @@ mod test { drop(w); // Allow the crank to turn more: - resolves_immediately(reader.ready()) - .await - .expect("ready is ok"); + resolves_immediately(reader.ready()).await; // Now we expect the reader to be empty, and the stream closed: let (bs, state) = reader.read(4097).unwrap(); diff --git a/crates/wasi/src/preview2/poll.rs b/crates/wasi/src/preview2/poll.rs index 68b6ce84c14f..54962aa3672c 100644 --- a/crates/wasi/src/preview2/poll.rs +++ b/crates/wasi/src/preview2/poll.rs @@ -1,13 +1,13 @@ -use crate::preview2::{bindings::io::poll, WasiView}; +use crate::preview2::{bindings::io::poll, Table, WasiView}; use anyhow::Result; use std::any::Any; -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use wasmtime::component::Resource; -pub type PollableFuture<'a> = Pin> + Send + 'a>>; +pub type PollableFuture<'a> = Pin + Send + 'a>>; pub type MakeFuture = for<'a> fn(&'a mut dyn Any) -> PollableFuture<'a>; pub type ClosureFuture = Box PollableFuture<'static> + Send + Sync + 'static>; @@ -17,15 +17,50 @@ pub type ClosureFuture = Box PollableFuture<'static> + Send + Sync + /// repeatedly check for readiness of a given condition, e.g. if a stream is readable /// or writable. So, rather than containing a Future, which can only become Ready once, a /// Pollable contains a way to create a Future in each call to `poll_list`. -pub enum Pollable { - /// Create a Future by calling a fn on another resource in the table. This - /// indirection means the created Future can use a mut borrow of another - /// resource in the Table (e.g. a stream) - TableEntry { index: u32, make_future: MakeFuture }, - /// Create a future by calling an owned, static closure. This is used for - /// pollables which do not share state with another resource in the Table - /// (e.g. a timer) - Closure(ClosureFuture), +pub struct Pollable { + index: u32, + make_future: MakeFuture, + remove_index_on_delete: Option Result<()>>, +} + +#[async_trait::async_trait] +pub trait Subscribe: Send + Sync + 'static { + async fn ready(&mut self); +} + +/// Creates a `pollable` resource which is susbcribed to the provided +/// `resource`. +/// +/// If `resource` is an owned resource then it will be deleted when the returned +/// resource is deleted. Otherwise the returned resource is considered a "child" +/// of the given `resource` which means that the given resource cannot be +/// deleted while the `pollable` is still alive. +pub fn subscribe(table: &mut Table, resource: Resource) -> Result> +where + T: Subscribe, +{ + fn make_future<'a, T>(stream: &'a mut dyn Any) -> PollableFuture<'a> + where + T: Subscribe, + { + stream.downcast_mut::().unwrap().ready() + } + + let pollable = Pollable { + index: resource.rep(), + remove_index_on_delete: if resource.owned() { + Some(|table, idx| { + let resource = Resource::::new_own(idx); + table.delete_resource(resource)?; + Ok(()) + }) + } else { + None + }, + make_future: make_future::, + }; + + Ok(table.push_child_resource(pollable, &resource)?) } #[async_trait::async_trait] @@ -36,88 +71,69 @@ impl poll::Host for T { let table = self.table_mut(); let mut table_futures: HashMap)> = HashMap::new(); - let mut closure_futures: Vec<(PollableFuture<'_>, Vec)> = Vec::new(); for (ix, p) in pollables.iter().enumerate() { let ix: u32 = ix.try_into()?; - match table.get_resource_mut(&p)? { - Pollable::Closure(f) => closure_futures.push((f(), vec![ix])), - Pollable::TableEntry { index, make_future } => match table_futures.entry(*index) { - Entry::Vacant(v) => { - v.insert((*make_future, vec![ix])); - } - Entry::Occupied(mut o) => { - let (_, v) = o.get_mut(); - v.push(ix); - } - }, - } + + let pollable = table.get_resource(p)?; + let (_, list) = table_futures + .entry(pollable.index) + .or_insert((pollable.make_future, Vec::new())); + list.push(ix); } + let mut futures: Vec<(PollableFuture<'_>, Vec)> = Vec::new(); for (entry, (make_future, readylist_indices)) in table.iter_entries(table_futures) { let entry = entry?; - closure_futures.push((make_future(entry), readylist_indices)); + futures.push((make_future(entry), readylist_indices)); } struct PollList<'a> { - elems: Vec<(PollableFuture<'a>, Vec)>, + futures: Vec<(PollableFuture<'a>, Vec)>, } impl<'a> Future for PollList<'a> { - type Output = Result>; + type Output = Vec; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut any_ready = false; let mut results = Vec::new(); - for (fut, readylist_indicies) in self.elems.iter_mut() { + for (fut, readylist_indicies) in self.futures.iter_mut() { match fut.as_mut().poll(cx) { - Poll::Ready(Ok(())) => { + Poll::Ready(()) => { results.extend_from_slice(readylist_indicies); any_ready = true; } - Poll::Ready(Err(e)) => { - return Poll::Ready(Err( - e.context(format!("poll_list {readylist_indicies:?}")) - )); - } Poll::Pending => {} } } if any_ready { - Poll::Ready(Ok(results)) + Poll::Ready(results) } else { Poll::Pending } } } - Ok(PollList { - elems: closure_futures, - } - .await?) + Ok(PollList { futures }.await) } async fn poll_one(&mut self, pollable: Resource) -> Result<()> { - use anyhow::Context; - let table = self.table_mut(); - let closure_future = match table.get_resource_mut(&pollable)? { - Pollable::Closure(f) => f(), - Pollable::TableEntry { index, make_future } => { - let index = *index; - let make_future = *make_future; - make_future(table.get_as_any_mut(index)?) - } - }; - - closure_future.await.context("poll_one") + let pollable = table.get_resource(&pollable)?; + let ready = (pollable.make_future)(table.get_as_any_mut(pollable.index)?); + ready.await; + Ok(()) } } #[async_trait::async_trait] impl crate::preview2::bindings::io::poll::HostPollable for T { fn drop(&mut self, pollable: Resource) -> Result<()> { - self.table_mut().delete_resource(pollable)?; + let pollable = self.table_mut().delete_resource(pollable)?; + if let Some(delete) = pollable.remove_index_on_delete { + delete(self.table_mut(), pollable.index)?; + } Ok(()) } } diff --git a/crates/wasi/src/preview2/stdio.rs b/crates/wasi/src/preview2/stdio.rs index b13aabf5215f..0d61adc2438a 100644 --- a/crates/wasi/src/preview2/stdio.rs +++ b/crates/wasi/src/preview2/stdio.rs @@ -318,7 +318,7 @@ mod test { let mut buffer = String::new(); loop { println!("child: waiting for stdin to be ready"); - stdin.ready().await.unwrap(); + stdin.ready().await; println!("child: reading input"); let (bytes, status) = stdin.read(1024).unwrap(); diff --git a/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs b/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs index bf933ad8398e..1d1c80aabace 100644 --- a/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs +++ b/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs @@ -23,6 +23,7 @@ //! This module is one that's likely to change over time though as new systems //! are encountered along with preexisting bugs. +use crate::preview2::poll::Subscribe; use crate::preview2::stdio::StdinStream; use crate::preview2::{HostInputStream, StreamState}; use anyhow::Error; @@ -145,8 +146,11 @@ impl HostInputStream for Stdin { } } } +} - async fn ready(&mut self) -> Result<(), Error> { +#[async_trait::async_trait] +impl Subscribe for Stdin { + async fn ready(&mut self) { let g = GlobalStdin::get(); // Scope the synchronous `state.lock()` to this block which does not @@ -161,12 +165,10 @@ impl HostInputStream for Stdin { g.read_completed.notified() } StdinState::ReadRequested => g.read_completed.notified(), - StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return Ok(()), + StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return, } }; notified.await; - - Ok(()) } } diff --git a/crates/wasi/src/preview2/stream.rs b/crates/wasi/src/preview2/stream.rs index 03a0ac1f0d46..7300b64073a4 100644 --- a/crates/wasi/src/preview2/stream.rs +++ b/crates/wasi/src/preview2/stream.rs @@ -1,5 +1,7 @@ use crate::preview2::filesystem::FileInputStream; +use crate::preview2::poll::Subscribe; use anyhow::Error; +use anyhow::Result; use bytes::Bytes; use std::fmt; @@ -44,7 +46,7 @@ impl StreamState { /// Host trait for implementing the `wasi:io/streams.input-stream` resource: A /// bytestream which can be read from. #[async_trait::async_trait] -pub trait HostInputStream: Send + Sync { +pub trait HostInputStream: Subscribe { /// Read bytes. On success, returns a pair holding the number of bytes /// read and a flag indicating whether the end of the stream was reached. /// Important: this read must be non-blocking! @@ -69,11 +71,6 @@ pub trait HostInputStream: Send + Sync { Ok((nread, state)) } - - /// Check for read readiness: this method blocks until the stream is ready - /// for reading. - /// Returning an error will trap execution. - async fn ready(&mut self) -> Result<(), Error>; } #[derive(Debug)] @@ -103,7 +100,7 @@ impl std::error::Error for OutputStreamError { /// Host trait for implementing the `wasi:io/streams.output-stream` resource: /// A bytestream which can be written to. #[async_trait::async_trait] -pub trait HostOutputStream: Send + Sync { +pub trait HostOutputStream: Subscribe { /// Write bytes after obtaining a permit to write those bytes /// Prior to calling [`write`](Self::write) /// the caller must call [`write_ready`](Self::write_ready), @@ -141,16 +138,17 @@ pub trait HostOutputStream: Send + Sync { /// - caller performed an illegal operation (e.g. wrote more bytes than were permitted) fn flush(&mut self) -> Result<(), OutputStreamError>; - /// Returns a future, which: - /// - when pending, indicates 0 bytes are permitted for writing - /// - when ready, returns a non-zero number of bytes permitted to write + /// Returns the number of bytes that are ready to be written to this stream. + /// + /// Zero bytes indicates that this stream is not currently ready for writing + /// and `ready()` must be awaited first. /// /// # Errors /// /// Returns an [OutputStreamError] if: /// - stream is closed /// - prior operation ([`write`](Self::write) or [`flush`](Self::flush)) failed - async fn write_ready(&mut self) -> Result; + fn check_write(&mut self) -> Result; /// Repeatedly write a byte to a stream. /// Important: this write must be non-blocking! @@ -163,6 +161,20 @@ pub trait HostOutputStream: Send + Sync { self.write(bs)?; Ok(()) } + + /// Simultaneously waits for this stream to be writable and then returns how + /// much may be written or the last error that happened. + async fn write_ready(&mut self) -> Result { + self.ready().await; + self.check_write() + } +} + +#[async_trait::async_trait] +impl Subscribe for Box { + async fn ready(&mut self) { + (**self).ready().await + } } pub enum InputStream { @@ -170,4 +182,15 @@ pub enum InputStream { File(FileInputStream), } +#[async_trait::async_trait] +impl Subscribe for InputStream { + async fn ready(&mut self) { + match self { + InputStream::Host(stream) => stream.ready().await, + // Files are always ready + InputStream::File(_) => {} + } + } +} + pub type OutputStream = Box; diff --git a/crates/wasi/src/preview2/tcp.rs b/crates/wasi/src/preview2/tcp.rs index 50b8a07c2f47..61bef55dd240 100644 --- a/crates/wasi/src/preview2/tcp.rs +++ b/crates/wasi/src/preview2/tcp.rs @@ -1,11 +1,15 @@ use super::{HostInputStream, HostOutputStream, OutputStreamError}; +use crate::preview2::poll::Subscribe; use crate::preview2::stream::{InputStream, OutputStream}; use crate::preview2::{with_ambient_tokio_runtime, AbortOnDropJoinHandle, StreamState}; +use anyhow::{Error, Result}; use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt}; use cap_std::net::TcpListener; use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; use std::io; +use std::mem; use std::sync::Arc; +use tokio::io::Interest; /// The state of a TCP socket. /// @@ -102,13 +106,15 @@ impl HostInputStream for TcpReadStream { buf.truncate(n); Ok((buf.freeze(), self.stream_state())) } +} - async fn ready(&mut self) -> Result<(), anyhow::Error> { +#[async_trait::async_trait] +impl Subscribe for TcpReadStream { + async fn ready(&mut self) { if self.closed { - return Ok(()); + return; } - self.stream.readable().await?; - Ok(()) + self.stream.readable().await.unwrap(); } } @@ -116,53 +122,60 @@ const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024; pub(crate) struct TcpWriteStream { stream: Arc, - write_handle: Option>>, + last_write: LastWrite, +} + +enum LastWrite { + Waiting(AbortOnDropJoinHandle>), + Error(Error), + Done, } impl TcpWriteStream { pub(crate) fn new(stream: Arc) -> Self { Self { stream, - write_handle: None, + last_write: LastWrite::Done, } } /// Write `bytes` in a background task, remembering the task handle for use in a future call to /// `write_ready` fn background_write(&mut self, mut bytes: bytes::Bytes) { - assert!(self.write_handle.is_none()); + assert!(matches!(self.last_write, LastWrite::Done)); let stream = self.stream.clone(); - self.write_handle - .replace(crate::preview2::spawn(async move { - // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream - // primitive try_write, which goes directly to attempt a write with mio. This has - // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream - // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need - // to flush. - while !bytes.is_empty() { - stream.writable().await?; - match stream.try_write(&bytes) { - Ok(n) => { - let _ = bytes.split_to(n); - } - Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue, - Err(e) => return Err(e.into()), + self.last_write = LastWrite::Waiting(crate::preview2::spawn(async move { + // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream + // primitive try_write, which goes directly to attempt a write with mio. This has + // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream + // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need + // to flush. + while !bytes.is_empty() { + stream.writable().await?; + match stream.try_write(&bytes) { + Ok(n) => { + let _ = bytes.split_to(n); } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue, + Err(e) => return Err(e.into()), } + } - Ok(()) - })); + Ok(()) + })); } } -#[async_trait::async_trait] impl HostOutputStream for TcpWriteStream { fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), OutputStreamError> { - if self.write_handle.is_some() { - return Err(OutputStreamError::Trap(anyhow::anyhow!( - "unpermitted: cannot write while background write ongoing" - ))); + match self.last_write { + LastWrite::Done => {} + LastWrite::Waiting(_) | LastWrite::Error(_) => { + return Err(OutputStreamError::Trap(anyhow::anyhow!( + "unpermitted: must call check_write first" + ))); + } } while !bytes.is_empty() { match self.stream.try_write(&bytes) { @@ -192,26 +205,40 @@ impl HostOutputStream for TcpWriteStream { Ok(()) } - async fn write_ready(&mut self) -> Result { - if let Some(handle) = &mut self.write_handle { - handle - .await - .map_err(|e| OutputStreamError::LastOperationFailed(e.into()))?; - - // Only clear out the write handle once the task has exited, to ensure that - // `write_ready` remains cancel-safe. - self.write_handle = None; + fn check_write(&mut self) -> Result { + match mem::replace(&mut self.last_write, LastWrite::Done) { + LastWrite::Waiting(task) => { + self.last_write = LastWrite::Waiting(task); + return Ok(0); + } + LastWrite::Done => {} + LastWrite::Error(e) => return Err(OutputStreamError::LastOperationFailed(e.into())), } - self.stream - .writable() - .await - .map_err(|e| OutputStreamError::LastOperationFailed(e.into()))?; - + let writable = self.stream.writable(); + futures::pin_mut!(writable); + if super::poll_noop(writable).is_none() { + return Ok(0); + } Ok(SOCKET_READY_SIZE) } } +#[async_trait::async_trait] +impl Subscribe for TcpWriteStream { + async fn ready(&mut self) { + if let LastWrite::Waiting(task) = &mut self.last_write { + self.last_write = match task.await { + Ok(()) => LastWrite::Done, + Err(e) => LastWrite::Error(e), + }; + } + if let LastWrite::Done = self.last_write { + self.stream.writable().await.unwrap(); + } + } +} + impl TcpSocket { /// Create a new socket in the given family. pub fn new(family: AddressFamily) -> io::Result { @@ -251,3 +278,20 @@ impl TcpSocket { (InputStream::Host(input), output) } } + +#[async_trait::async_trait] +impl Subscribe for TcpSocket { + async fn ready(&mut self) { + // Some states are ready immediately. + match self.tcp_state { + TcpState::BindStarted | TcpState::ListenStarted | TcpState::ConnectReady => return, + _ => {} + } + + // FIXME: Add `Interest::ERROR` when we update to tokio 1.32. + self.inner + .ready(Interest::READABLE | Interest::WRITABLE) + .await + .unwrap(); + } +} diff --git a/crates/wasi/src/preview2/write_stream.rs b/crates/wasi/src/preview2/write_stream.rs index bf154aec0720..726d0d2674bb 100644 --- a/crates/wasi/src/preview2/write_stream.rs +++ b/crates/wasi/src/preview2/write_stream.rs @@ -1,4 +1,4 @@ -use crate::preview2::{HostOutputStream, OutputStreamError}; +use crate::preview2::{HostOutputStream, OutputStreamError, Subscribe}; use anyhow::anyhow; use bytes::Bytes; use std::sync::{Arc, Mutex}; @@ -35,11 +35,6 @@ enum Job { Write(Bytes), } -enum WriteStatus<'a> { - Done(Result), - Pending(tokio::sync::futures::Notified<'a>), -} - impl Worker { fn new(write_budget: usize) -> Self { Self { @@ -54,17 +49,31 @@ impl Worker { write_ready_changed: tokio::sync::Notify::new(), } } - fn check_write(&self) -> WriteStatus<'_> { + async fn ready(&self) { + loop { + { + let state = self.state(); + if state.error.is_some() + || !state.alive + || (!state.flush_pending && state.write_budget > 0) + { + return; + } + } + self.write_ready_changed.notified().await; + } + } + fn check_write(&self) -> Result { let mut state = self.state(); if let Err(e) = state.check_error() { - return WriteStatus::Done(Err(e)); + return Err(e); } if state.flush_pending || state.write_budget == 0 { - return WriteStatus::Pending(self.write_ready_changed.notified()); + return Ok(0); } - WriteStatus::Done(Ok(state.write_budget)) + Ok(state.write_budget) } fn state(&self) -> std::sync::MutexGuard { self.state.lock().unwrap() @@ -154,7 +163,6 @@ impl AsyncWriteStream { } } -#[async_trait::async_trait] impl HostOutputStream for AsyncWriteStream { fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { let mut state = self.worker.state(); @@ -185,12 +193,13 @@ impl HostOutputStream for AsyncWriteStream { Ok(()) } - async fn write_ready(&mut self) -> Result { - loop { - match self.worker.check_write() { - WriteStatus::Done(r) => return r, - WriteStatus::Pending(notifier) => notifier.await, - } - } + fn check_write(&mut self) -> Result { + self.worker.check_write() + } +} +#[async_trait::async_trait] +impl Subscribe for AsyncWriteStream { + async fn ready(&mut self) { + self.worker.ready().await; } }