Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor how subscribe works in WASI #7130

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions crates/test-programs/wasi-tests/src/bin/poll_oneoff_files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,23 +153,29 @@ unsafe fn test_fd_readwrite(readable_fd: wasi::Fd, writable_fd: wasi::Fd, error_
];
let out = poll_oneoff_with_retry(&r#in).unwrap();
assert_eq!(out.len(), 2, "should return 2 events, got: {:?}", out);

let (read, write) = if out[0].userdata == 1 {
(&out[0], &out[1])
} else {
(&out[1], &out[0])
};
assert_eq!(
out[0].userdata, 1,
read.userdata, 1,
"the event.userdata should contain fd userdata specified by the user"
);
assert_errno!(out[0].error, error_code);
assert_errno!(read.error, error_code);
assert_eq!(
out[0].type_,
read.type_,
wasi::EVENTTYPE_FD_READ,
"the event.type_ should equal FD_READ"
);
assert_eq!(
out[1].userdata, 2,
write.userdata, 2,
"the event.userdata should contain fd userdata specified by the user"
);
assert_errno!(out[1].error, error_code);
assert_errno!(write.error, error_code);
assert_eq!(
out[1].type_,
write.type_,
wasi::EVENTTYPE_FD_WRITE,
"the event.type_ should equal FD_WRITE"
);
Expand Down
61 changes: 36 additions & 25 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::preview2::{
self, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, OutputStreamError,
StreamRuntimeError, StreamState,
StreamRuntimeError, StreamState, Subscribe,
};

pub type HyperIncomingBody = BoxBody<Bytes, anyhow::Error>;
Expand Down Expand Up @@ -189,14 +189,17 @@ impl HostInputStream for HostIncomingBodyStream {
}
}
}
}

async fn ready(&mut self) -> anyhow::Result<()> {
#[async_trait::async_trait]
impl Subscribe for HostIncomingBodyStream {
async fn ready(&mut self) {
if !self.buffer.is_empty() {
return Ok(());
return;
}

if !self.open {
return Ok(());
return;
}

match self.receiver.recv().await {
Expand All @@ -209,8 +212,6 @@ impl HostInputStream for HostIncomingBodyStream {

None => self.open = false,
}

Ok(())
}
}

Expand All @@ -224,8 +225,9 @@ pub enum HostFutureTrailersState {
Done(Result<FieldMap, types::Error>),
}

impl HostFutureTrailers {
pub async fn ready(&mut self) -> anyhow::Result<()> {
#[async_trait::async_trait]
impl Subscribe for HostFutureTrailers {
async fn ready(&mut self) {
if let HostFutureTrailersState::Waiting(rx) = &mut self.state {
let result = match rx.await {
Ok(Ok(headers)) => Ok(FieldMap::from(headers)),
Expand All @@ -236,7 +238,6 @@ impl HostFutureTrailers {
};
self.state = HostFutureTrailersState::Done(result);
}
Ok(())
}
}

Expand Down Expand Up @@ -353,11 +354,6 @@ enum Job {
Write(Bytes),
}

enum WriteStatus<'a> {
Done(Result<usize, OutputStreamError>),
Pending(tokio::sync::futures::Notified<'a>),
}

impl Worker {
fn new(write_budget: usize) -> Self {
Self {
Expand All @@ -372,17 +368,31 @@ impl Worker {
write_ready_changed: tokio::sync::Notify::new(),
}
}
fn check_write(&self) -> WriteStatus<'_> {
async fn ready(&self) {
loop {
{
let state = self.state();
if state.error.is_some()
|| !state.alive
|| (!state.flush_pending && state.write_budget > 0)
{
return;
}
}
self.write_ready_changed.notified().await;
}
}
fn check_write(&self) -> Result<usize, OutputStreamError> {
let mut state = self.state();
if let Err(e) = state.check_error() {
return WriteStatus::Done(Err(e));
return Err(e);
}

if state.flush_pending || state.write_budget == 0 {
return WriteStatus::Pending(self.write_ready_changed.notified());
return Ok(0);
}

WriteStatus::Done(Ok(state.write_budget))
Ok(state.write_budget)
}
fn state(&self) -> std::sync::MutexGuard<WorkerState> {
self.state.lock().unwrap()
Expand Down Expand Up @@ -496,12 +506,13 @@ impl HostOutputStream for BodyWriteStream {
Ok(())
}

async fn write_ready(&mut self) -> Result<usize, OutputStreamError> {
loop {
match self.worker.check_write() {
WriteStatus::Done(r) => return r,
WriteStatus::Pending(notifier) => notifier.await,
}
}
fn check_write(&mut self) -> Result<usize, OutputStreamError> {
self.worker.check_write()
}
}
#[async_trait::async_trait]
impl Subscribe for BodyWriteStream {
async fn ready(&mut self) {
self.worker.ready().await
}
}
24 changes: 6 additions & 18 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ use crate::{
},
};
use std::any::Any;
use std::pin::Pin;
use std::task;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Table, TableError};
use wasmtime_wasi::preview2::{AbortOnDropJoinHandle, Subscribe, Table, TableError};

/// Capture the state necessary for use in the wasi-http API implementation.
pub struct WasiHttpCtx;
Expand Down Expand Up @@ -167,21 +165,11 @@ impl HostFutureIncomingResponse {
}
}

impl std::future::Future for HostFutureIncomingResponse {
type Output = anyhow::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
let s = self.get_mut();
match s {
Self::Pending(ref mut handle) => match Pin::new(handle).poll(cx) {
task::Poll::Pending => task::Poll::Pending,
task::Poll::Ready(r) => {
*s = Self::Ready(r);
task::Poll::Ready(Ok(()))
}
},

Self::Consumed | Self::Ready(_) => task::Poll::Ready(Ok(())),
#[async_trait::async_trait]
impl Subscribe for HostFutureIncomingResponse {
async fn ready(&mut self) {
if let Self::Pending(handle) = self {
*self = Self::Ready(handle.await);
}
}
}
Expand Down
39 changes: 9 additions & 30 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::any::Any;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2::{
bindings::io::streams::{InputStream, OutputStream},
Pollable, PollableFuture,
Pollable,
};

impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
Expand Down Expand Up @@ -352,19 +352,10 @@ impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
&mut self,
index: FutureTrailers,
) -> wasmtime::Result<Resource<Pollable>> {
// Eagerly force errors about the validity of the index.
let _ = self.table().get_future_trailers(index)?;

fn make_future(elem: &mut dyn Any) -> PollableFuture {
Box::pin(elem.downcast_mut::<HostFutureTrailers>().unwrap().ready())
}

// FIXME: this should use `push_child_resource`
let id = self
.table()
.push_resource(Pollable::TableEntry { index, make_future })?;

Ok(id)
wasmtime_wasi::preview2::subscribe(
self.table(),
Resource::<HostFutureTrailers>::new_borrow(index),
)
}

fn future_trailers_get(
Expand Down Expand Up @@ -480,22 +471,10 @@ impl<T: WasiHttpView> crate::bindings::http::types::Host for T {
&mut self,
id: FutureIncomingResponse,
) -> wasmtime::Result<Resource<Pollable>> {
let _ = self.table().get_future_incoming_response(id)?;

fn make_future<'a>(elem: &'a mut dyn Any) -> PollableFuture<'a> {
Box::pin(
elem.downcast_mut::<HostFutureIncomingResponse>()
.expect("parent resource is HostFutureIncomingResponse"),
)
}

// FIXME: this should use `push_child_resource`
let pollable = self.table().push_resource(Pollable::TableEntry {
index: id,
make_future,
})?;

Ok(pollable)
wasmtime_wasi::preview2::subscribe(
self.table(),
Resource::<HostFutureIncomingResponse>::new_borrow(id),
)
}

fn outgoing_body_write(
Expand Down
Loading