diff --git a/Cargo.lock b/Cargo.lock index c3582a8a19c0..f9dac1c99e2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3138,6 +3138,7 @@ dependencies = [ "cap-rand", "cap-std", "io-extras", + "log", "rustix", "thiserror", "tracing", diff --git a/crates/wasi-common/Cargo.toml b/crates/wasi-common/Cargo.toml index 065a03b52638..3e77c5c8488f 100644 --- a/crates/wasi-common/Cargo.toml +++ b/crates/wasi-common/Cargo.toml @@ -26,6 +26,7 @@ tracing = { workspace = true } cap-std = { workspace = true } cap-rand = { workspace = true } bitflags = { workspace = true } +log = { workspace = true } [target.'cfg(unix)'.dependencies] rustix = { workspace = true, features = ["fs"] } diff --git a/crates/wasi-common/cap-std-sync/src/file.rs b/crates/wasi-common/cap-std-sync/src/file.rs index c54a5ead7145..c184486a7f24 100644 --- a/crates/wasi-common/cap-std-sync/src/file.rs +++ b/crates/wasi-common/cap-std-sync/src/file.rs @@ -93,7 +93,7 @@ impl WasiFile for File { let fdflags = get_fd_flags(&*file)?; Ok(fdflags) } - async fn set_fdflags(&self, fdflags: FdFlags) -> Result<(), Error> { + async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> { if fdflags.intersects( wasi_common::file::FdFlags::DSYNC | wasi_common::file::FdFlags::SYNC @@ -101,7 +101,7 @@ impl WasiFile for File { ) { return Err(Error::invalid_argument().context("cannot set DSYNC, SYNC, or RSYNC flag")); } - let mut file = self.0.write().unwrap(); + let file = self.0.get_mut().unwrap(); let set_fd_flags = (*file).new_set_fd_flags(to_sysif_fdflags(fdflags))?; (*file).set_fd_flags(set_fd_flags)?; Ok(()) diff --git a/crates/wasi-common/cap-std-sync/src/net.rs b/crates/wasi-common/cap-std-sync/src/net.rs index b46d8d7dafaf..921622d68d96 100644 --- a/crates/wasi-common/cap-std-sync/src/net.rs +++ b/crates/wasi-common/cap-std-sync/src/net.rs @@ -98,7 +98,7 @@ macro_rules! wasi_listen_write_impl { } async fn sock_accept(&self, fdflags: FdFlags) -> Result, Error> { let (stream, _) = self.0.accept()?; - let stream = <$stream>::from_cap_std(stream); + let mut stream = <$stream>::from_cap_std(stream); stream.set_fdflags(fdflags).await?; Ok(Box::new(stream)) } @@ -110,7 +110,7 @@ macro_rules! wasi_listen_write_impl { let fdflags = get_fd_flags(&self.0)?; Ok(fdflags) } - async fn set_fdflags(&self, fdflags: FdFlags) -> Result<(), Error> { + async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> { if fdflags == wasi_common::file::FdFlags::NONBLOCK { self.0.set_nonblocking(true)?; } else if fdflags.is_empty() { @@ -197,7 +197,7 @@ macro_rules! wasi_stream_write_impl { let fdflags = get_fd_flags(&self.0)?; Ok(fdflags) } - async fn set_fdflags(&self, fdflags: FdFlags) -> Result<(), Error> { + async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> { if fdflags == wasi_common::file::FdFlags::NONBLOCK { self.0.set_nonblocking(true)?; } else if fdflags.is_empty() { diff --git a/crates/wasi-common/src/ctx.rs b/crates/wasi-common/src/ctx.rs index 1a115e0d8d43..f6b08e1cfef6 100644 --- a/crates/wasi-common/src/ctx.rs +++ b/crates/wasi-common/src/ctx.rs @@ -2,17 +2,26 @@ use crate::clocks::WasiClocks; use crate::dir::{DirCaps, DirEntry, WasiDir}; use crate::file::{FileCaps, FileEntry, WasiFile}; use crate::sched::WasiSched; -use crate::string_array::{StringArray, StringArrayError}; +use crate::string_array::StringArray; use crate::table::Table; -use crate::Error; +use crate::{Error, StringArrayError}; use cap_rand::RngCore; +use std::ops::Deref; use std::path::{Path, PathBuf}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; -pub struct WasiCtx { +/// An `Arc`-wrapper around the wasi-common context to allow mutable access to +/// the file descriptor table. This wrapper is only necessary due to the +/// signature of `fd_fdstat_set_flags`; if that changes, there are a variety of +/// improvements that can be made (TODO: +/// https://github.com/bytecodealliance/wasmtime/issues/5643). +#[derive(Clone)] +pub struct WasiCtx(Arc); + +pub struct WasiCtxInner { pub args: StringArray, pub env: StringArray, - pub random: Box, + pub random: Mutex>, pub clocks: WasiClocks, pub sched: Box, pub table: Table, @@ -25,14 +34,14 @@ impl WasiCtx { sched: Box, table: Table, ) -> Self { - let s = WasiCtx { + let s = WasiCtx(Arc::new(WasiCtxInner { args: StringArray::new(), env: StringArray::new(), - random, + random: Mutex::new(random), clocks, sched, table, - }; + })); s.set_stdin(Box::new(crate::pipe::ReadPipe::new(std::io::empty()))); s.set_stdout(Box::new(crate::pipe::WritePipe::new(std::io::sink()))); s.set_stderr(Box::new(crate::pipe::WritePipe::new(std::io::sink()))); @@ -77,12 +86,22 @@ impl WasiCtx { &self.table } + pub fn table_mut(&mut self) -> Option<&mut Table> { + Arc::get_mut(&mut self.0).map(|c| &mut c.table) + } + pub fn push_arg(&mut self, arg: &str) -> Result<(), StringArrayError> { - self.args.push(arg.to_owned()) + let s = Arc::get_mut(&mut self.0).expect( + "`push_arg` should only be used during initialization before the context is cloned", + ); + s.args.push(arg.to_owned()) } pub fn push_env(&mut self, var: &str, value: &str) -> Result<(), StringArrayError> { - self.env.push(format!("{}={}", var, value))?; + let s = Arc::get_mut(&mut self.0).expect( + "`push_env` should only be used during initialization before the context is cloned", + ); + s.env.push(format!("{}={}", var, value))?; Ok(()) } @@ -130,3 +149,10 @@ impl WasiCtx { Ok(()) } } + +impl Deref for WasiCtx { + type Target = WasiCtxInner; + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/crates/wasi-common/src/file.rs b/crates/wasi-common/src/file.rs index b76278373a65..828307255e56 100644 --- a/crates/wasi-common/src/file.rs +++ b/crates/wasi-common/src/file.rs @@ -64,7 +64,7 @@ pub trait WasiFile: Send + Sync { Ok(FdFlags::empty()) } - async fn set_fdflags(&self, _flags: FdFlags) -> Result<(), Error> { + async fn set_fdflags(&mut self, _flags: FdFlags) -> Result<(), Error> { Err(Error::badf()) } @@ -217,11 +217,15 @@ pub struct Filestat { pub(crate) trait TableFileExt { fn get_file(&self, fd: u32) -> Result, Error>; + fn get_file_mut(&mut self, fd: u32) -> Result<&mut FileEntry, Error>; } impl TableFileExt for crate::table::Table { fn get_file(&self, fd: u32) -> Result, Error> { self.get(fd) } + fn get_file_mut(&mut self, fd: u32) -> Result<&mut FileEntry, Error> { + self.get_mut(fd) + } } pub(crate) struct FileEntry { @@ -272,6 +276,7 @@ impl FileEntry { pub trait FileEntryExt { fn get_cap(&self, caps: FileCaps) -> Result<&dyn WasiFile, Error>; + fn get_cap_mut(&mut self, caps: FileCaps) -> Result<&mut dyn WasiFile, Error>; } impl FileEntryExt for FileEntry { @@ -279,6 +284,10 @@ impl FileEntryExt for FileEntry { self.capable_of(caps)?; Ok(&*self.file) } + fn get_cap_mut(&mut self, caps: FileCaps) -> Result<&mut dyn WasiFile, Error> { + self.capable_of(caps)?; + Ok(&mut *self.file) + } } bitflags! { diff --git a/crates/wasi-common/src/snapshots/preview_1.rs b/crates/wasi-common/src/snapshots/preview_1.rs index 094270d51d3b..76ead4f090f7 100644 --- a/crates/wasi-common/src/snapshots/preview_1.rs +++ b/crates/wasi-common/src/snapshots/preview_1.rs @@ -189,11 +189,16 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { fd: types::Fd, flags: types::Fdflags, ) -> Result<(), Error> { - self.table() - .get_file(u32::from(fd))? - .get_cap(FileCaps::FDSTAT_SET_FLAGS)? - .set_fdflags(FdFlags::from(flags)) - .await + if let Some(table) = self.table_mut() { + table + .get_file_mut(u32::from(fd))? + .get_cap_mut(FileCaps::FDSTAT_SET_FLAGS)? + .set_fdflags(FdFlags::from(flags)) + .await + } else { + log::warn!("`fd_fdstat_set_flags` does not work with wasi-threads enabled; see https://github.com/bytecodealliance/wasmtime/issues/5643"); + Err(Error::invalid_argument()) + } } async fn fd_fdstat_set_rights( @@ -1110,7 +1115,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { while copied < buf.len() { let len = (buf.len() - copied).min(MAX_SHARED_BUFFER_SIZE as u32); let mut tmp = vec![0; len as usize]; - self.random.try_fill_bytes(&mut tmp)?; + self.random.lock().unwrap().try_fill_bytes(&mut tmp)?; let dest = buf .get_range(copied..copied + len) .unwrap() @@ -1122,7 +1127,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { // If the Wasm memory is non-shared, copy directly into the linear // memory. let mem = &mut buf.as_slice_mut()?.unwrap(); - self.random.try_fill_bytes(mem)?; + self.random.lock().unwrap().try_fill_bytes(mem)?; } Ok(()) } diff --git a/crates/wasi-common/src/table.rs b/crates/wasi-common/src/table.rs index a92fefff6a2a..40069636786e 100644 --- a/crates/wasi-common/src/table.rs +++ b/crates/wasi-common/src/table.rs @@ -76,6 +76,22 @@ impl Table { } } + /// Get a mutable reference to a resource of a given type at a given index. + /// Only one such reference can be borrowed at any given time. + pub fn get_mut(&mut self, key: u32) -> Result<&mut T, Error> { + let entry = match self.0.get_mut().unwrap().map.get_mut(&key) { + Some(entry) => entry, + None => return Err(Error::badf().context("key not in table")), + }; + let entry = match Arc::get_mut(entry) { + Some(entry) => entry, + None => return Err(Error::badf().context("cannot mutably borrow shared file")), + }; + entry + .downcast_mut::() + .ok_or_else(|| Error::badf().context("element is a different type")) + } + /// Remove a resource at a given index from the table. Returns the resource /// if it was present. pub fn delete(&self, key: u32) -> Option> { diff --git a/crates/wasi-common/tokio/src/file.rs b/crates/wasi-common/tokio/src/file.rs index 91e70aa58397..0217e315b783 100644 --- a/crates/wasi-common/tokio/src/file.rs +++ b/crates/wasi-common/tokio/src/file.rs @@ -116,7 +116,7 @@ macro_rules! wasi_file_impl { async fn get_fdflags(&self) -> Result { block_on_dummy_executor(|| self.0.get_fdflags()) } - async fn set_fdflags(&self, fdflags: FdFlags) -> Result<(), Error> { + async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> { block_on_dummy_executor(|| self.0.set_fdflags(fdflags)) } async fn get_filestat(&self) -> Result { diff --git a/src/commands/run.rs b/src/commands/run.rs index 1ea33fe9f820..0c5ba2e29f0c 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -178,7 +178,7 @@ impl RunCommand { // Read the wasm module binary either as `*.wat` or a raw binary. let module = self.load_module(linker.engine(), &self.module)?; - let mut host = Arc::new(Host::default()); + let mut host = Host::default(); populate_with_wasi( &mut host, &mut linker, @@ -290,8 +290,8 @@ impl RunCommand { fn load_main_module( &self, - store: &mut Store>, - linker: &mut Linker>, + store: &mut Store, + linker: &mut Linker, module: Module, ) -> Result<()> { if let Some(timeout) = self.wasm_timeout { @@ -324,8 +324,8 @@ impl RunCommand { fn invoke_export( &self, - store: &mut Store>, - linker: &Linker>, + store: &mut Store, + linker: &Linker, name: &str, ) -> Result<()> { let func = match linker @@ -339,12 +339,7 @@ impl RunCommand { self.invoke_func(store, func, Some(name)) } - fn invoke_func( - &self, - store: &mut Store>, - func: Func, - name: Option<&str>, - ) -> Result<()> { + fn invoke_func(&self, store: &mut Store, func: Func, name: Option<&str>) -> Result<()> { let ty = func.ty(&store); if ty.params().len() > 0 { eprintln!( @@ -419,22 +414,22 @@ impl RunCommand { } } -#[derive(Default)] +#[derive(Default, Clone)] struct Host { wasi: Option, #[cfg(feature = "wasi-crypto")] - wasi_crypto: Option, + wasi_crypto: Option>, #[cfg(feature = "wasi-nn")] - wasi_nn: Option, + wasi_nn: Option>, #[cfg(feature = "wasi-threads")] - wasi_threads: Option>>, + wasi_threads: Option>>, } /// Populates the given `Linker` with WASI APIs. fn populate_with_wasi( - host: &mut Arc, - linker: &mut Linker>, - module: Module, + host: &mut Host, + linker: &mut Linker, + _module: Module, preopen_dirs: Vec<(String, Dir)>, argv: &[String], vars: &[(String, String)], @@ -443,7 +438,7 @@ fn populate_with_wasi( mut tcplisten: Vec, ) -> Result<()> { if wasi_modules.wasi_common { - wasmtime_wasi::add_to_linker(linker, |host| host.wasi.as_ref().unwrap())?; + wasmtime_wasi::add_to_linker(linker, |host| host.wasi.as_mut().unwrap())?; let mut builder = WasiCtxBuilder::new(); builder = builder.inherit_stdio().args(argv)?.envs(vars)?; @@ -465,9 +460,7 @@ fn populate_with_wasi( builder = builder.preopened_dir(dir, name)?; } - Arc::get_mut(host) - .expect("there must be no other host references during setup") - .wasi = Some(builder.build()); + host.wasi = Some(builder.build()); } if wasi_modules.wasi_crypto { @@ -478,14 +471,11 @@ fn populate_with_wasi( #[cfg(feature = "wasi-crypto")] { wasmtime_wasi_crypto::add_to_linker(linker, |host| { - Arc::get_mut(host) - .wasi_crypto + host.wasi_crypto .as_mut() .expect("wasi-crypto is not implemented with multi-threading support") })?; - Arc::get_mut(host) - .expect("there must be no other host references during setup") - .wasi_crypto = Some(WasiCryptoCtx::new()); + host.wasi_crypto = Some(Arc::new(WasiCryptoCtx::new())); } } @@ -497,14 +487,11 @@ fn populate_with_wasi( #[cfg(feature = "wasi-nn")] { wasmtime_wasi_nn::add_to_linker(linker, |host| { - Arc::get_mut(host) - .wasi_nn + host.wasi_nn .as_mut() .expect("wasi-nn is not implemented with multi-threading support") })?; - Arc::get_mut(host) - .expect("there must be no other host references during setup") - .wasi_nn = Some(WasiNnCtx::new()?); + host.wasi_nn = Some(Arc::new(WasiNnCtx::new()?)); } } @@ -515,12 +502,10 @@ fn populate_with_wasi( } #[cfg(feature = "wasi-threads")] { - wasmtime_wasi_threads::add_to_linker(linker, &module, |host| { + wasmtime_wasi_threads::add_to_linker(linker, &_module, |host| { host.wasi_threads.as_ref().unwrap() })?; - Arc::get_mut(host) - .expect("there must be no other host references during setup") - .wasi_threads = Some(WasiThreadsCtx::new(module, linker.clone())); + host.wasi_threads = Some(Arc::new(WasiThreadsCtx::new(_module, linker.clone()))); } }