diff --git a/ci/run-wasi-nn-example.sh b/ci/run-wasi-nn-example.sh index b8012120d3fa..09947da16d65 100755 --- a/ci/run-wasi-nn-example.sh +++ b/ci/run-wasi-nn-example.sh @@ -1,10 +1,11 @@ #!/bin/bash -# The following script demonstrates how to execute a machine learning inference using the wasi-nn -# module optionally compiled into Wasmtime. Calling it will download the necessary model and tensor -# files stored separately in $FIXTURE into $TMP_DIR (optionally pass a directory with existing files -# as the first argument to re-try the script). Then, it will compile the example code in -# crates/wasi-nn/tests/example into a Wasm file that is subsequently executed with the Wasmtime CLI. +# The following script demonstrates how to execute a machine learning inference +# using the wasi-nn module optionally compiled into Wasmtime. Calling it will +# download the necessary model and tensor files stored separately in $FIXTURE +# into $TMP_DIR (optionally pass a directory with existing files as the first +# argument to re-try the script). Then, it will compile and run several examples +# in the Wasmtime CLI. set -e WASMTIME_DIR=$(dirname "$0" | xargs dirname) FIXTURE=https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet @@ -18,7 +19,12 @@ else REMOVE_TMP_DIR=0 fi -# Build Wasmtime with wasi-nn enabled; we attempt this first to avoid extra work if the build fails. +# One of the examples expects to be in a specifically-named directory. +mkdir -p $TMP_DIR/mobilenet +TMP_DIR=$TMP_DIR/mobilenet + +# Build Wasmtime with wasi-nn enabled; we attempt this first to avoid extra work +# if the build fails. cargo build -p wasmtime-cli --features wasi-nn # Download all necessary test fixtures to the temporary directory. 
@@ -26,16 +32,27 @@ wget --no-clobber $FIXTURE/mobilenet.bin --output-document=$TMP_DIR/model.bin wget --no-clobber $FIXTURE/mobilenet.xml --output-document=$TMP_DIR/model.xml wget --no-clobber $FIXTURE/tensor-1x224x224x3-f32.bgr --output-document=$TMP_DIR/tensor.bgr -# Now build an example that uses the wasi-nn API. +# Now build an example that uses the wasi-nn API. Run the example in Wasmtime +# (note that the example uses `fixture` as the expected location of the +# model/tensor files). pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example cargo build --release --target=wasm32-wasi cp target/wasm32-wasi/release/wasi-nn-example.wasm $TMP_DIR popd +cargo run -- run --mapdir fixture::$TMP_DIR \ + --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example.wasm -# Run the example in Wasmtime (note that the example uses `fixture` as the expected location of the model/tensor files). -cargo run -- run --mapdir fixture::$TMP_DIR --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example.wasm +# Build and run another example, this time using Wasmtime's graph flag to +# preload the model. +pushd $WASMTIME_DIR/crates/wasi-nn/examples/classification-example-named +cargo build --release --target=wasm32-wasi +cp target/wasm32-wasi/release/wasi-nn-example-named.wasm $TMP_DIR +popd +cargo run -- run --mapdir fixture::$TMP_DIR --wasi-nn-graph openvino::$TMP_DIR \ + --wasi-modules=experimental-wasi-nn $TMP_DIR/wasi-nn-example-named.wasm -# Clean up the temporary directory only if it was not specified (users may want to keep the directory around). +# Clean up the temporary directory only if it was not specified (users may want +# to keep the directory around). 
if [[ $REMOVE_TMP_DIR -eq 1 ]]; then rm -rf $TMP_DIR fi diff --git a/crates/wasi-nn/examples/classification-example-named/Cargo.lock b/crates/wasi-nn/examples/classification-example-named/Cargo.lock new file mode 100644 index 000000000000..6a1c9eb0f62f --- /dev/null +++ b/crates/wasi-nn/examples/classification-example-named/Cargo.lock @@ -0,0 +1,74 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "proc-macro2" +version = "1.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9207952ae1a003f42d3d5e892dac3c6ba42aa6ac0c79a6a91a2b5cb4253e75c" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1728216d3244de4f14f14f8c15c79be1a7c67867d28d69b719690e2a19fb445" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" + +[[package]] +name = "wasi-nn" +version = "0.5.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "03d01b90f0cca3f19682e90e1bc3f5e3e441031e19e56ce7dbf034f3b3597552" +dependencies = [ + "thiserror", +] + +[[package]] +name = "wasi-nn-example-named" +version = "0.0.0" +dependencies = [ + "wasi-nn", +] diff --git a/crates/wasi-nn/examples/classification-example-named/Cargo.toml b/crates/wasi-nn/examples/classification-example-named/Cargo.toml new file mode 100644 index 000000000000..b4653659bd3d --- /dev/null +++ b/crates/wasi-nn/examples/classification-example-named/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "wasi-nn-example-named" +version = "0.0.0" +authors = ["The Wasmtime Project Developers"] +readme = "README.md" +edition = "2021" +publish = false + +[dependencies] +wasi-nn = "0.5.0" + +# This crate is built with the wasm32-wasi target, so it's separate +# from the main Wasmtime build, so use this directive to exclude it +# from the parent directory's workspace. +[workspace] diff --git a/crates/wasi-nn/examples/classification-example-named/README.md b/crates/wasi-nn/examples/classification-example-named/README.md new file mode 100644 index 000000000000..aa56ad0cbaf7 --- /dev/null +++ b/crates/wasi-nn/examples/classification-example-named/README.md @@ -0,0 +1,2 @@ +This example project demonstrates using the `wasi-nn` API to perform ML inference. It consists of Rust code that is +built using the `wasm32-wasi` target. See `ci/run-wasi-nn-example.sh` for how this is used. 
diff --git a/crates/wasi-nn/examples/classification-example-named/src/main.rs b/crates/wasi-nn/examples/classification-example-named/src/main.rs new file mode 100644 index 000000000000..f9bc7e1906be --- /dev/null +++ b/crates/wasi-nn/examples/classification-example-named/src/main.rs @@ -0,0 +1,53 @@ +use std::fs; +use wasi_nn::*; + +pub fn main() { + let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) + .build_from_cache("mobilenet") + .unwrap(); + println!("Loaded a graph: {:?}", graph); + + let mut context = graph.init_execution_context().unwrap(); + println!("Created an execution context: {:?}", context); + + // Load a tensor that precisely matches the graph input tensor (see + // `fixture/frozen_inference_graph.xml`). + let tensor_data = fs::read("fixture/tensor.bgr").unwrap(); + println!("Read input tensor, size in bytes: {}", tensor_data.len()); + context + .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data) + .unwrap(); + + // Execute the inference. + context.compute().unwrap(); + println!("Executed graph inference"); + + // Retrieve the output. + let mut output_buffer = vec![0f32; 1001]; + context.get_output(0, &mut output_buffer[..]).unwrap(); + + println!( + "Found results, sorted top 5: {:?}", + &sort_results(&output_buffer)[..5] + ) +} + +// Sort the buffer of probabilities. The graph places the match probability for each class at the +// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert +// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output +// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense +// (e.g. 
763 = "revolver" vs 762 = "restaurant") +fn sort_results(buffer: &[f32]) -> Vec { + let mut results: Vec = buffer + .iter() + .skip(1) + .enumerate() + .map(|(c, p)| InferenceResult(c, *p)) + .collect(); + results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + results +} + +// A wrapper for class ID and match probabilities. +#[derive(Debug, PartialEq)] +struct InferenceResult(usize, f32); diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs index ad929a9a74e0..5912b26b606d 100644 --- a/crates/wasi-nn/src/backend/mod.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -7,6 +7,7 @@ mod openvino; use self::openvino::OpenvinoBackend; use crate::wit::types::{ExecutionTarget, Tensor}; use crate::{ExecutionContext, Graph}; +use std::{error::Error, fmt, path::Path, str::FromStr}; use thiserror::Error; use wiggle::GuestError; @@ -15,16 +16,28 @@ pub fn list() -> Vec<(BackendKind, Box)> { vec![(BackendKind::OpenVINO, Box::new(OpenvinoBackend::default()))] } -/// A [Backend] contains the necessary state to load [BackendGraph]s. +/// A [Backend] contains the necessary state to load [Graph]s. pub trait Backend: Send + Sync { fn name(&self) -> &str; fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result; + fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir>; +} + +/// Some [Backend]s support loading a [Graph] from a directory on the +/// filesystem; this is not a general requirement for backends but is useful for +/// the Wasmtime CLI. +pub trait BackendFromDir: Backend { + fn load_from_dir( + &mut self, + builders: &Path, + target: ExecutionTarget, + ) -> Result; } /// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing /// implementation for a [crate::witx::types::Graph]. 
pub trait BackendGraph: Send + Sync { - fn init_execution_context(&mut self) -> Result; + fn init_execution_context(&self) -> Result; } /// A [BackendExecutionContext] performs the actual inference; this is the @@ -53,3 +66,20 @@ pub enum BackendError { pub enum BackendKind { OpenVINO, } +impl FromStr for BackendKind { + type Err = BackendKindParseError; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "openvino" => Ok(BackendKind::OpenVINO), + _ => Err(BackendKindParseError(s.into())), + } + } +} +#[derive(Debug)] +pub struct BackendKindParseError(String); +impl fmt::Display for BackendKindParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "unknown backend: {}", self.0) + } +} +impl Error for BackendKindParseError {} diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs index 93f51771c95f..a478ec9a4bdc 100644 --- a/crates/wasi-nn/src/backend/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -1,10 +1,11 @@ //! Implements a `wasi-nn` [`Backend`] using OpenVINO. 
-use super::{Backend, BackendError, BackendExecutionContext, BackendGraph}; +use super::{Backend, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph}; use crate::wit::types::{ExecutionTarget, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; +use std::{fs::File, io::Read, path::Path}; #[derive(Default)] pub(crate) struct OpenvinoBackend(Option); @@ -51,20 +52,42 @@ impl Backend for OpenvinoBackend { let exec_network = core.load_network(&cnn_network, map_execution_target_to_string(target))?; - let box_: Box = - Box::new(OpenvinoGraph(Arc::new(cnn_network), exec_network)); + let box_: Box = Box::new(OpenvinoGraph( + Arc::new(cnn_network), + Arc::new(Mutex::new(exec_network)), + )); Ok(box_.into()) } + + fn as_dir_loadable(&mut self) -> Option<&mut dyn BackendFromDir> { + Some(self) + } } -struct OpenvinoGraph(Arc, openvino::ExecutableNetwork); +impl BackendFromDir for OpenvinoBackend { + fn load_from_dir( + &mut self, + path: &Path, + target: ExecutionTarget, + ) -> Result { + let model = read(&path.join("model.xml"))?; + let weights = read(&path.join("model.bin"))?; + self.load(&[&model, &weights], target) + } +} + +struct OpenvinoGraph( + Arc, + Arc>, +); unsafe impl Send for OpenvinoGraph {} unsafe impl Sync for OpenvinoGraph {} impl BackendGraph for OpenvinoGraph { - fn init_execution_context(&mut self) -> Result { - let infer_request = self.1.create_infer_request()?; + fn init_execution_context(&self) -> Result { + let mut network = self.1.lock().unwrap(); + let infer_request = network.create_infer_request()?; let box_: Box = Box::new(OpenvinoExecutionContext(self.0.clone(), infer_request)); Ok(box_.into()) @@ -145,3 +168,11 @@ fn map_tensor_type_to_precision(tensor_type: TensorType) -> openvino::Precision TensorType::Bf16 => todo!("not yet supported in `openvino` bindings"), } } + +/// Read a file into a byte 
vector. +fn read(path: &Path) -> anyhow::Result> { + let mut file = File::open(path)?; + let mut buffer = vec![]; + file.read_to_end(&mut buffer)?; + Ok(buffer) +} diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs index e961938a8f3d..913205c6d81c 100644 --- a/crates/wasi-nn/src/ctx.rs +++ b/crates/wasi-nn/src/ctx.rs @@ -1,38 +1,59 @@ //! Implements the host state for the `wasi-nn` API: [WasiNnCtx]. -use crate::backend::{self, Backend, BackendError, BackendKind}; +use crate::backend::{Backend, BackendError, BackendKind}; use crate::wit::types::GraphEncoding; -use crate::{ExecutionContext, Graph}; -use std::{collections::HashMap, hash::Hash}; +use crate::{ExecutionContext, Graph, GraphRegistry, InMemoryRegistry}; +use anyhow::anyhow; +use std::{collections::HashMap, hash::Hash, path::Path}; use thiserror::Error; use wiggle::GuestError; type Backends = HashMap>; +type Registry = Box; type GraphId = u32; type GraphExecutionContextId = u32; +type BackendName = String; +type GraphDirectory = String; + +/// Construct an in-memory registry from the available backends and a list of +/// `(, )`. This assumes graphs can be loaded +/// from a local directory, which is a safe assumption currently for the current +/// model types. +pub fn preload( + preload_graphs: &[(BackendName, GraphDirectory)], +) -> anyhow::Result<(Backends, Registry)> { + let mut backends: HashMap<_, _> = crate::backend::list().into_iter().collect(); + let mut registry = InMemoryRegistry::new(); + for (kind, path) in preload_graphs { + let backend = backends + .get_mut(&kind.parse()?) + .ok_or(anyhow!("unsupported backend: {}", kind))? + .as_dir_loadable() + .ok_or(anyhow!("{} does not support directory loading", kind))?; + registry.load(backend, Path::new(path))?; + } + Ok((backends, Box::new(registry))) +} /// Capture the state necessary for calling into the backend ML libraries. 
pub struct WasiNnCtx { pub(crate) backends: Backends, + pub(crate) registry: Registry, pub(crate) graphs: Table, pub(crate) executions: Table, } impl WasiNnCtx { /// Make a new context from the default state. - pub fn new(backends: Backends) -> Self { + pub fn new(backends: Backends, registry: Registry) -> Self { Self { backends, + registry, graphs: Table::default(), executions: Table::default(), } } } -impl Default for WasiNnCtx { - fn default() -> Self { - WasiNnCtx::new(backend::list().into_iter().collect()) - } -} /// Possible errors while interacting with [WasiNnCtx]. #[derive(Debug, Error)] @@ -90,6 +111,10 @@ where key } + pub fn get(&self, key: K) -> Option<&V> { + self.entries.get(&key) + } + pub fn get_mut(&mut self, key: K) -> Option<&mut V> { self.entries.get_mut(&key) } @@ -106,7 +131,14 @@ mod test { use super::*; #[test] - fn instantiate() { - WasiNnCtx::default(); + fn example() { + struct FakeRegistry; + impl GraphRegistry for FakeRegistry { + fn get_mut(&mut self, _: &str) -> Option<&mut Graph> { + None + } + } + + let ctx = WasiNnCtx::new(HashMap::new(), Box::new(FakeRegistry)); } } diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 1abd6c0b1372..f5c9bfe641d8 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -1,15 +1,20 @@ mod backend; mod ctx; +mod registry; -pub use ctx::WasiNnCtx; +pub use ctx::{preload, WasiNnCtx}; +pub use registry::{GraphRegistry, InMemoryRegistry}; pub mod wit; pub mod witx; +use std::sync::Arc; + /// A backend-defined graph (i.e., ML model). -pub struct Graph(Box); +#[derive(Clone)] +pub struct Graph(Arc); impl From> for Graph { fn from(value: Box) -> Self { - Self(value) + Self(value.into()) } } impl std::ops::Deref for Graph { @@ -18,11 +23,6 @@ impl std::ops::Deref for Graph { self.0.as_ref() } } -impl std::ops::DerefMut for Graph { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.as_mut() - } -} /// A backend-defined execution context. 
pub struct ExecutionContext(Box); diff --git a/crates/wasi-nn/src/registry/in_memory.rs b/crates/wasi-nn/src/registry/in_memory.rs new file mode 100644 index 000000000000..b008f7f43684 --- /dev/null +++ b/crates/wasi-nn/src/registry/in_memory.rs @@ -0,0 +1,43 @@ +//! Implement a [`GraphRegistry`] with a hash map. + +use super::{Graph, GraphRegistry}; +use crate::backend::BackendFromDir; +use crate::wit::types::ExecutionTarget; +use anyhow::{anyhow, bail}; +use std::{collections::HashMap, path::Path}; + +pub struct InMemoryRegistry(HashMap); +impl InMemoryRegistry { + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Load a graph from the files contained in the `path` directory. + /// + /// This expects the backend to know how to load graphs (i.e., ML model) + /// from a directory. The name used in the registry is the directory's last + /// suffix: if the backend can find the files it expects in `/my/model/foo`, + /// the registry will contain a new graph named `foo`. + pub fn load(&mut self, backend: &mut dyn BackendFromDir, path: &Path) -> anyhow::Result<()> { + if !path.is_dir() { + bail!( + "preload directory is not a valid directory: {}", + path.display() + ); + } + let name = path + .file_name() + .map(|s| s.to_string_lossy()) + .ok_or(anyhow!("no file name in path"))?; + + let graph = backend.load_from_dir(path, ExecutionTarget::Cpu)?; + self.0.insert(name.into_owned(), graph); + Ok(()) + } +} + +impl GraphRegistry for InMemoryRegistry { + fn get_mut(&mut self, name: &str) -> Option<&mut Graph> { + self.0.get_mut(name) + } +} diff --git a/crates/wasi-nn/src/registry/mod.rs b/crates/wasi-nn/src/registry/mod.rs new file mode 100644 index 000000000000..83f88e4dca0e --- /dev/null +++ b/crates/wasi-nn/src/registry/mod.rs @@ -0,0 +1,16 @@ +//! Define the registry API. +//! +//! A [`GraphRegistry`] is place to store backend graphs so they can be loaded +//! by name. This API does not mandate how a graph is loaded or how it must be +//! 
stored--it could be stored remotely and rematerialized when needed, e.g. A +//! naive in-memory implementation, [`InMemoryRegistry`] is provided for use +//! with the Wasmtime CLI. + +mod in_memory; + +use crate::Graph; +pub use in_memory::InMemoryRegistry; + +pub trait GraphRegistry: Send + Sync { + fn get_mut(&mut self, name: &str) -> Option<&mut Graph>; +} diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index 25510374d5cf..2b2032bc3109 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -53,9 +53,14 @@ impl gen::graph::Host for WasiNnCtx { fn load_by_name( &mut self, - _name: String, + name: String, ) -> wasmtime::Result> { - todo!() + if let Some(graph) = self.registry.get_mut(&name) { + let graph_id = self.graphs.insert(graph.clone().into()); + Ok(Ok(graph_id)) + } else { + return Err(UsageError::NotFound(name.to_string()).into()); + } } } @@ -67,7 +72,7 @@ impl gen::inference::Host for WasiNnCtx { &mut self, graph_id: gen::graph::Graph, ) -> wasmtime::Result> { - let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) { + let exec_context = if let Some(graph) = self.graphs.get(graph_id) { graph.init_execution_context()? 
} else { return Err(UsageError::InvalidGraphHandle.into()); diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index b339d9a8d389..bf5a57b8e980 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -78,8 +78,14 @@ impl<'a> gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { Ok(graph_id.into()) } - fn load_by_name<'b>(&mut self, _name: &wiggle::GuestPtr<'b, str>) -> Result { - todo!() + fn load_by_name<'b>(&mut self, name: &wiggle::GuestPtr<'b, str>) -> Result { + let name = name.as_str()?.unwrap(); + if let Some(graph) = self.registry.get_mut(&name) { + let graph_id = self.graphs.insert(graph.clone().into()); + Ok(graph_id.into()) + } else { + return Err(UsageError::NotFound(name.to_string()).into()); + } } fn init_execution_context( diff --git a/src/commands/run.rs b/src/commands/run.rs index 89a050e1b105..3a3f080c9dca 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -51,6 +51,14 @@ fn parse_map_dirs(s: &str) -> Result<(String, String)> { Ok((parts[0].into(), parts[1].into())) } +fn parse_graphs(s: &str) -> Result<(String, String)> { + let parts: Vec<&str> = s.split("::").collect(); + if parts.len() != 2 { + bail!("must contain exactly one double colon ('::')"); + } + Ok((parts[0].into(), parts[1].into())) +} + fn parse_dur(s: &str) -> Result { // assume an integer without a unit specified is a number of seconds ... if let Ok(val) = s.parse() { @@ -158,6 +166,17 @@ pub struct RunCommand { #[clap(long = "mapdir", number_of_values = 1, value_name = "GUEST_DIR::HOST_DIR", value_parser = parse_map_dirs)] map_dirs: Vec<(String, String)>, + /// Pre-load machine learning graphs (i.e., models) for use by wasi-nn. + /// + /// Each use of the flag will preload a ML model from the host directory + /// using the given model encoding. The model will be mapped to the + /// directory name: e.g., `--wasi-nn-graph openvino::/foo/bar` will preload + /// an OpenVINO model named `bar`. 
Note that which model encodings are + /// available is dependent on the backends implemented in the + /// `wasmtime_wasi_nn` crate. + #[clap(long = "wasi-nn-graph", value_name = "FORMAT::HOST_DIR", value_parser = parse_graphs)] + graphs: Vec<(String, String)>, + /// Load the given WebAssembly module before the main module #[clap( long = "preload", @@ -922,7 +941,8 @@ impl RunCommand { })?; } } - store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::default())); + let (backends, registry) = wasmtime_wasi_nn::preload(&self.graphs)?; + store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new(backends, registry))); } }