Skip to content

Commit

Permalink
wasi: Fix a few issues around stdin (bytecodealliance#7063)
Browse files Browse the repository at this point in the history
* wasi: Fix a few issues around stdin

This commit is intended to address bytecodealliance#6986 and some other issues related
to stdin and reading it, notably:

* Previously once EOF was reached the `closed` flag was mistakenly not
  set.
* Previously data would be infinitely buffered regardless of how fast
  the guest program would consume it.
* Previously stdin would be immediately ready by Wasmtime regardless of
  whether the guest wanted to read stdin or not.
* The host-side preview1-to-preview2 adapter didn't perform a blocking
  read meaning that it never blocked.

These issues are addressed by refactoring the code in question.
Note that this is similar to the logic of `AsyncReadStream` somewhat but
that type is not appropriate in this context due to the singleton nature
of stdin meaning that the per-stream helper task and per-stream buffer
of `AsyncReadStream` are not appropriate.

Closees bytecodealliance#6986

* Increase slop size for windows
  • Loading branch information
alexcrichton authored and eduardomourar committed Sep 22, 2023
1 parent fe6563a commit 4bced80
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 98 deletions.
5 changes: 2 additions & 3 deletions crates/wasi/src/preview2/pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ pub struct AsyncReadStream {
state: StreamState,
buffer: Option<Result<Bytes, std::io::Error>>,
receiver: mpsc::Receiver<Result<(Bytes, StreamState), std::io::Error>>,
#[allow(unused)] // just used to implement unix stdin
pub(crate) join_handle: crate::preview2::AbortOnDropJoinHandle<()>,
_join_handle: crate::preview2::AbortOnDropJoinHandle<()>,
}

impl AsyncReadStream {
Expand Down Expand Up @@ -150,7 +149,7 @@ impl AsyncReadStream {
state: StreamState::Open,
buffer: None,
receiver,
join_handle,
_join_handle: join_handle,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/wasi/src/preview2/preview1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ impl<
return Ok(0);
};
let (read, state) = stream_res(
streams::Host::read(
streams::Host::blocking_read(
self,
input_stream,
buf.len().try_into().unwrap_or(u64::MAX),
Expand Down
213 changes: 120 additions & 93 deletions crates/wasi/src/preview2/stdio/worker_thread_stdin.rs
Original file line number Diff line number Diff line change
@@ -1,91 +1,103 @@
//! Handling for standard in using a worker task.
//!
//! Standard input is a global singleton resource for the entire program which
//! needs special care. Currently this implementation adheres to a few
//! constraints which make this nontrivial to implement.
//!
//! * Any number of guest wasm programs can read stdin. While this doesn't make
//! a ton of sense semantically they shouldn't block forever. Instead it's a
//! race to see who actually reads which parts of stdin.
//!
//! * Data from stdin isn't actually read unless requested. This is done to try
//! to be a good neighbor to others running in the process. Under the
//! assumption that most programs have one "thing" which reads stdin the
//! actual consumption of bytes is delayed until the wasm guest is dynamically
//! chosen to be that "thing". Before that data from stdin is not consumed to
//! avoid taking it from other components in the process.
//!
//! * Tokio's documentation indicates that "interactive stdin" is best done with
//! a helper thread to avoid blocking shutdown of the event loop. That's
//! respected here where all stdin reading happens on a blocking helper thread
//! that, at this time, is never shut down.
//!
//! This module is one that's likely to change over time though as new systems
//! are encountered along with preexisting bugs.
use crate::preview2::{HostInputStream, StreamState};
use anyhow::Error;
use bytes::{Bytes, BytesMut};
use std::io::Read;
use std::sync::Arc;
use tokio::sync::watch;

// wasmtime cant use std::sync::OnceLock yet because of a llvm regression in
// 1.70. when 1.71 is released, we can switch to using std here.
use once_cell::sync::OnceCell as OnceLock;

use std::sync::Mutex;
use std::mem;
use std::sync::{Condvar, Mutex, OnceLock};
use tokio::sync::Notify;

#[derive(Default)]
struct GlobalStdin {
// Worker thread uses this to notify of new events. Ready checks use this
// to create a new Receiver via .subscribe(). The newly created receiver
// will only wait for events created after the call to subscribe().
tx: Arc<watch::Sender<()>>,
// Worker thread and receivers share this state to get bytes read off
// stdin, or the error/closed state.
state: Arc<Mutex<StdinState>>,
state: Mutex<StdinState>,
read_requested: Condvar,
read_completed: Notify,
}

#[derive(Debug)]
struct StdinState {
// Bytes read off stdin.
buffer: BytesMut,
// Error read off stdin, if any.
error: Option<std::io::Error>,
// If an error has occured in the past, we consider the stream closed.
closed: bool,
#[derive(Default, Debug)]
enum StdinState {
#[default]
ReadNotRequested,
ReadRequested,
Data(BytesMut),
Error(std::io::Error),
Closed,
}

static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
impl GlobalStdin {
fn get() -> &'static GlobalStdin {
static STDIN: OnceLock<GlobalStdin> = OnceLock::new();
STDIN.get_or_init(|| create())
}
}

fn create() -> GlobalStdin {
let (tx, _rx) = watch::channel(());
let tx = Arc::new(tx);

let state = Arc::new(Mutex::new(StdinState {
buffer: BytesMut::new(),
error: None,
closed: false,
}));

let ret = GlobalStdin {
state: state.clone(),
tx: tx.clone(),
};

std::thread::spawn(move || loop {
let mut bytes = BytesMut::zeroed(1024);
match std::io::stdin().lock().read(&mut bytes) {
// Reading `0` indicates that stdin has reached EOF, so we break
// the loop to allow the thread to exit.
Ok(0) => break,

Ok(nbytes) => {
// Append to the buffer:
bytes.truncate(nbytes);
let mut locked = state.lock().unwrap();
locked.buffer.extend_from_slice(&bytes);
}
Err(e) => {
// Set the error, and mark the stream as closed:
let mut locked = state.lock().unwrap();
if locked.error.is_none() {
locked.error = Some(e)
std::thread::spawn(|| {
let state = GlobalStdin::get();
loop {
// Wait for a read to be requested, but don't hold the lock across
// the blocking read.
let mut lock = state.state.lock().unwrap();
lock = state
.read_requested
.wait_while(lock, |state| !matches!(state, StdinState::ReadRequested))
.unwrap();
drop(lock);

let mut bytes = BytesMut::zeroed(1024);
let (new_state, done) = match std::io::stdin().read(&mut bytes) {
Ok(0) => (StdinState::Closed, true),
Ok(nbytes) => {
bytes.truncate(nbytes);
(StdinState::Data(bytes), false)
}
locked.closed = true;
Err(e) => (StdinState::Error(e), true),
};

// After the blocking read completes the state should not have been
// tampered with.
debug_assert!(matches!(
*state.state.lock().unwrap(),
StdinState::ReadRequested
));
*state.state.lock().unwrap() = new_state;
state.read_completed.notify_waiters();
if done {
break;
}
}
// Receivers may or may not exist - fine if they dont, new
// ones will be created with subscribe()
let _ = tx.send(());
});
ret

GlobalStdin::default()
}

/// Only public interface is the [`HostInputStream`] impl.
#[derive(Clone)]
pub struct Stdin;
impl Stdin {
// Private! Only required internally.
fn get_global() -> &'static GlobalStdin {
STDIN.get_or_init(|| create())
}
}

pub fn stdin() -> Stdin {
Stdin
Expand All @@ -100,40 +112,55 @@ impl is_terminal::IsTerminal for Stdin {
#[async_trait::async_trait]
impl HostInputStream for Stdin {
fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> {
let g = Stdin::get_global();
let g = GlobalStdin::get();
let mut locked = g.state.lock().unwrap();

if let Some(e) = locked.error.take() {
return Err(e.into());
match mem::replace(&mut *locked, StdinState::ReadRequested) {
StdinState::ReadNotRequested => {
g.read_requested.notify_one();
Ok((Bytes::new(), StreamState::Open))
}
StdinState::ReadRequested => Ok((Bytes::new(), StreamState::Open)),
StdinState::Data(mut data) => {
let size = data.len().min(size);
let bytes = data.split_to(size);
*locked = if data.is_empty() {
StdinState::ReadNotRequested
} else {
StdinState::Data(data)
};
Ok((bytes.freeze(), StreamState::Open))
}
StdinState::Error(e) => {
*locked = StdinState::Closed;
return Err(e.into());
}
StdinState::Closed => {
*locked = StdinState::Closed;
Ok((Bytes::new(), StreamState::Closed))
}
}
let size = locked.buffer.len().min(size);
let bytes = locked.buffer.split_to(size);
let state = if locked.buffer.is_empty() && locked.closed {
StreamState::Closed
} else {
StreamState::Open
};
Ok((bytes.freeze(), state))
}

async fn ready(&mut self) -> Result<(), Error> {
let g = Stdin::get_global();

// Block makes sure we dont hold the mutex across the await:
let mut rx = {
let locked = g.state.lock().unwrap();
// read() will only return (empty, open) when the buffer is empty,
// AND there is no error AND the stream is still open:
if !locked.buffer.is_empty() || locked.error.is_some() || locked.closed {
return Ok(());
let g = GlobalStdin::get();

// Scope the synchronous `state.lock()` to this block which does not
// `.await` inside of it.
let notified = {
let mut locked = g.state.lock().unwrap();
match *locked {
// If a read isn't requested yet
StdinState::ReadNotRequested => {
g.read_requested.notify_one();
*locked = StdinState::ReadRequested;
g.read_completed.notified()
}
StdinState::ReadRequested => g.read_completed.notified(),
StdinState::Data(_) | StdinState::Closed | StdinState::Error(_) => return Ok(()),
}
// Sender will take the mutex before updating the state of
// subscribe, so this ensures we will only await for any stdin
// events that are recorded after we drop the mutex:
g.tx.subscribe()
};

rx.changed().await.expect("impossible for sender to drop");
notified.await;

Ok(())
}
Expand Down
86 changes: 85 additions & 1 deletion tests/all/cli_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use anyhow::{bail, Result};
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::process::{Command, Output};
use std::process::{Command, Output, Stdio};
use tempfile::{NamedTempFile, TempDir};

// Run the wasmtime CLI with the provided args and return the `Output`.
Expand Down Expand Up @@ -932,3 +932,87 @@ fn option_group_boolean_parsing() -> Result<()> {
])?;
Ok(())
}

#[test]
fn preview2_stdin() -> Result<()> {
let test = "tests/all/cli_tests/count-stdin.wat";
let cmd = || -> Result<_> {
let mut cmd = get_wasmtime_command()?;
cmd.arg("--invoke=count").arg("-Spreview2").arg(test);
Ok(cmd)
};

// read empty pipe is ok
let output = cmd()?.output()?;
assert!(output.status.success());
assert_eq!(String::from_utf8_lossy(&output.stdout), "0\n");

// read itself is ok
let file = File::open(test)?;
let size = file.metadata()?.len();
let output = cmd()?.stdin(File::open(test)?).output()?;
assert!(output.status.success());
assert_eq!(String::from_utf8_lossy(&output.stdout), format!("{size}\n"));

// read piped input ok is ok
let mut child = cmd()?
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let mut stdin = child.stdin.take().unwrap();
std::thread::spawn(move || {
stdin.write_all(b"hello").unwrap();
});
let output = child.wait_with_output()?;
assert!(output.status.success());
assert_eq!(String::from_utf8_lossy(&output.stdout), "5\n");

let count_up_to = |n: usize| -> Result<_> {
let mut child = get_wasmtime_command()?
.arg("--invoke=count-up-to")
.arg("-Spreview2")
.arg(test)
.arg(n.to_string())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let mut stdin = child.stdin.take().unwrap();
let t = std::thread::spawn(move || {
let mut written = 0;
let bytes = [0; 64 * 1024];
loop {
written += match stdin.write(&bytes) {
Ok(n) => n,
Err(_) => break written,
};
}
});
let output = child.wait_with_output()?;
assert!(output.status.success());
let written = t.join().unwrap();
let read = String::from_utf8_lossy(&output.stdout)
.trim()
.parse::<usize>()
.unwrap();
// The test reads in 1000 byte chunks so make sure that it doesn't read
// more than 1000 bytes than requested.
assert!(read < n + 1000, "test read too much {read}");
Ok(written)
};

// wasmtime shouldn't eat information that the guest never actually tried to
// read.
//
// NB: this may be a bit flaky. Exactly how much we wrote in the above
// helper thread depends on how much the OS buffers for us. For now give
// some some slop and assume that OSes are unlikely to buffer more than
// that.
let slop = 256 * 1024;
for amt in [0, 100, 100_000] {
let written = count_up_to(amt)?;
assert!(written < slop + amt, "wrote too much {written}");
}
Ok(())
}
Loading

0 comments on commit 4bced80

Please sign in to comment.