diff --git a/crates/wasi/src/preview2/preview1.rs b/crates/wasi/src/preview2/preview1.rs index 282f50ae75f6..3b7730a93274 100644 --- a/crates/wasi/src/preview2/preview1.rs +++ b/crates/wasi/src/preview2/preview1.rs @@ -83,52 +83,68 @@ impl BlockingMode { output_stream: Resource, mut bytes: &[u8], ) -> StreamResult { - use streams::HostOutputStream as Streams; + use streams::HostOutputStream; + + async fn nonblocking_write_blocking_flush( + host: &mut (impl streams::Host + poll::Host), + output_stream: Resource, + bytes: &[u8], + ) -> StreamResult { + let n = match HostOutputStream::check_write(host, output_stream.borrowed()) { + Ok(n) => n, + Err(StreamError::Closed) => 0, + Err(e) => Err(e)?, + }; + + let len = bytes.len().min(n as usize); + if len == 0 { + return Ok(0); + } + + match HostOutputStream::write(host, output_stream.borrowed(), bytes[..len].to_vec()) { + Ok(()) => {} + Err(StreamError::Closed) => return Ok(0), + Err(e) => Err(e)?, + } + + match HostOutputStream::blocking_flush(host, output_stream.borrowed()).await { + Ok(()) => {} + Err(StreamError::Closed) => return Ok(0), + Err(e) => Err(e)?, + }; + + Ok(len) + } match self { BlockingMode::Blocking => { - let total = bytes.len(); + let mut total = 0; + let pollable = HostOutputStream::subscribe(host, output_stream.borrowed()) + .map_err(StreamError::Trap)?; while !bytes.is_empty() { - // NOTE: blocking_write_and_flush takes at most one 4k buffer. - let len = bytes.len().min(4096); - let (chunk, rest) = bytes.split_at(len); - bytes = rest; - - Streams::blocking_write_and_flush( + poll::Host::poll_one(host, pollable.borrowed()) + .await + .map_err(StreamError::Trap)?; + match nonblocking_write_blocking_flush( host, output_stream.borrowed(), - Vec::from(chunk), + &bytes[total..], ) .await? + { + 0 => return Ok(total), + n => { + total += n; + let (_, rest) = bytes.split_at(n); + bytes = rest; + } + } } - + poll::HostPollable::drop(host, pollable).map_err(StreamError::Trap)?; Ok(total) } BlockingMode::NonBlocking => { - let n = match Streams::check_write(host, output_stream.borrowed()) { - Ok(n) => n, - Err(StreamError::Closed) => 0, - Err(e) => Err(e)?, - }; - - let len = bytes.len().min(n as usize); - if len == 0 { - return Ok(0); - } - - match Streams::write(host, output_stream.borrowed(), bytes[..len].to_vec()) { - Ok(()) => {} - Err(StreamError::Closed) => return Ok(0), - Err(e) => Err(e)?, - } - - match Streams::blocking_flush(host, output_stream.borrowed()).await { - Ok(()) => {} - Err(StreamError::Closed) => return Ok(0), - Err(e) => Err(e)?, - }; - - Ok(len) + nonblocking_write_blocking_flush(host, output_stream, bytes).await } } }