diff --git a/Cargo.lock b/Cargo.lock index eccee7eb5785..47e695e49d79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,15 +371,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "castaway" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc" -dependencies = [ - "rustversion", -] - [[package]] name = "cc" version = "1.0.83" @@ -486,19 +477,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" -[[package]] -name = "compact_str" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "ryu", - "static_assertions", -] - [[package]] name = "component-fuzz-util" version = "0.0.0" @@ -1908,9 +1886,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] @@ -2028,21 +2006,22 @@ dependencies = [ [[package]] name = "ort" -version = "2.0.0-rc.0" +version = "2.0.0-rc.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd" +checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" dependencies = [ - "compact_str", + "js-sys", "ort-sys", "thiserror", "tracing", + "web-sys", ] [[package]] name = "ort-sys" -version = "2.0.0-rc.0" +version = "2.0.0-rc.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46" +checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" dependencies = [ "flate2", "sha2", @@ -3936,9 +3915,10 @@ dependencies = [ "test-programs-artifacts", "thiserror", "tracing", + "tracing-subscriber", "walkdir", - "wasi-common", "wasmtime", + "wasmtime-wasi", "wiggle", "windows", ] @@ -4024,6 +4004,16 @@ dependencies = [ "wast 211.0.1", ] +[[package]] +name = "web-sys" +version = "0.3.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.1" diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 30c1cbb9283c..53ce9ad80cb2 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -36,5 +36,9 @@ cp -r $dst crates/wasi-http/wit # slightly different than above. repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn revision=e2310b -curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx +# TODO: the in-tree `wasi-nn` implementation does not yet fully support the +# latest WIT specification on `main`. To create a baseline for moving forward, +# the in-tree WIT incorporates some but not all of the upstream changes. This +# TODO can be removed once the implementation catches up with the spec. 
+# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit diff --git a/crates/bench-api/src/lib.rs b/crates/bench-api/src/lib.rs index 166db3dc2fe4..9b89faddcbda 100644 --- a/crates/bench-api/src/lib.rs +++ b/crates/bench-api/src/lib.rs @@ -418,7 +418,7 @@ struct BenchState { struct HostState { wasi: WasiCtx, #[cfg(feature = "wasi-nn")] - wasi_nn: wasmtime_wasi_nn::WasiNnCtx, + wasi_nn: wasmtime_wasi_nn::witx::WasiNnCtx, } impl BenchState { @@ -509,7 +509,7 @@ impl BenchState { #[cfg(feature = "wasi-nn")] wasi_nn: { let (backends, registry) = wasmtime_wasi_nn::preload(&[])?; - wasmtime_wasi_nn::WasiNnCtx::new(backends, registry) + wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry) }, }; diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index 20209eaf6de6..7fd653f9b72a 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -90,7 +90,10 @@ fn build_and_generate_tests() { } // Generate a component from each test. - if kind == "nn" || target == "dwarf_imported_memory" || target == "dwarf_shared_memory" { + if target == "dwarf_imported_memory" + || target == "dwarf_shared_memory" + || target.starts_with("nn_witx") + { continue; } let adapter = match target.as_str() { diff --git a/crates/test-programs/src/bin/nn_image_classification_winml.rs b/crates/test-programs/src/bin/nn_image_classification_winml.rs deleted file mode 100644 index 0dc7e8843525..000000000000 --- a/crates/test-programs/src/bin/nn_image_classification_winml.rs +++ /dev/null @@ -1,16 +0,0 @@ -use anyhow::{Context, Result}; -use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; - -pub fn main() -> Result<()> { - let graph = GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU) - .build_from_cache("mobilenet")?; - let tensor = fs::read("fixture/kitten.rgb") - .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; - let top_five = &sort_results(&results)[..5]; - println!("found results, sorted top 5: {:?}", top_five); - assert_eq!(top_five[0].class_id(), 284); - Ok(()) -} diff --git a/crates/test-programs/src/bin/nn_image_classification_onnx.rs b/crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs similarity index 67% rename from crates/test-programs/src/bin/nn_image_classification_onnx.rs rename to crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs index abb77d0e7339..5900e664afde 100644 --- a/crates/test-programs/src/bin/nn_image_classification_onnx.rs +++ b/crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs @@ -1,18 +1,20 @@ use anyhow::{Context, Result}; use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; +use test_programs::nn::{sort_results, wit}; pub fn main() -> Result<()> { let model = fs::read("fixture/model.onnx") .context("the model file to be mapped to the fixture directory")?; - let graph = - GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?; + let graph = wit::load( + &[model], + wit::GraphEncoding::Onnx, + wit::ExecutionTarget::Cpu, + )?; let tensor = fs::read("fixture/000000062808.rgb") .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; + let results = wit::classify(graph, ("input", tensor), "output")?; let top_five = &sort_results(&results)[..5]; - // 963 is 
meat loaf, meatloaf. + // 963 is "meat loaf, meatloaf." // https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963 assert_eq!(top_five[0].class_id(), 963); println!("found results, sorted top 5: {:?}", top_five); diff --git a/crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs b/crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs new file mode 100644 index 000000000000..52bb6eb5967c --- /dev/null +++ b/crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs @@ -0,0 +1,25 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, wit}; + +pub fn main() -> Result<()> { + let xml = fs::read("fixture/model.xml") + .context("the model file to be mapped to the fixture directory")?; + let weights = fs::read("fixture/model.bin") + .context("the weights file to be mapped to the fixture directory")?; + let graph = wit::load( + &[xml, weights], + wit::GraphEncoding::Openvino, + wit::ExecutionTarget::Cpu, + )?; + let tensor = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = wit::classify( + graph, + ("input", tensor), + "MobilenetV2/Predictions/Reshape_1", + )?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs b/crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs new file mode 100644 index 000000000000..482c77043206 --- /dev/null +++ b/crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs @@ -0,0 +1,17 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, wit}; + +pub fn main() -> Result<()> { + let graph = wit::load_by_name("fixtures")?; + let tensor: Vec = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = wit::classify( + graph, + ("input", tensor), + "MobilenetV2/Predictions/Reshape_1", + )?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_image_classification_named.rs b/crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs similarity index 53% rename from crates/test-programs/src/bin/nn_image_classification_named.rs rename to crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs index 9b75a5afb4c6..f3840062e207 100644 --- a/crates/test-programs/src/bin/nn_image_classification_named.rs +++ b/crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs @@ -1,15 +1,14 @@ use anyhow::{Context, Result}; use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; +use test_programs::nn::{sort_results, wit}; pub fn main() -> Result<()> { - let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) - .build_from_cache("fixtures")?; + let graph = wit::load_by_name("mobilenet")?; let tensor = fs::read("fixture/tensor.bgr") .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; + let results = wit::classify(graph, ("input", tensor), "output")?; let top_five = &sort_results(&results)[..5]; println!("found results, sorted top 5: {:?}", top_five); + assert_eq!(top_five[0].class_id(), 284); Ok(()) } diff --git 
a/crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs b/crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs new file mode 100644 index 000000000000..fc4991ca44ff --- /dev/null +++ b/crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs @@ -0,0 +1,22 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, witx}; + +pub fn main() -> Result<()> { + let model = fs::read("fixture/model.onnx") + .context("the model file to be mapped to the fixture directory")?; + let graph = witx::load( + &[&model], + witx::GraphEncoding::Onnx, + witx::ExecutionTarget::CPU, + )?; + let tensor = fs::read("fixture/000000062808.rgb") + .context("the tensor file to be mapped to the fixture directory")?; + let results = witx::classify(graph, tensor)?; + let top_five = &sort_results(&results)[..5]; + // 963 is "meat loaf, meatloaf." + // https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963 + assert_eq!(top_five[0].class_id(), 963); + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_image_classification.rs b/crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs similarity index 66% rename from crates/test-programs/src/bin/nn_image_classification.rs rename to crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs index 5815503c3f76..5d557e3b8274 100644 --- a/crates/test-programs/src/bin/nn_image_classification.rs +++ b/crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs @@ -1,18 +1,20 @@ use anyhow::{Context, Result}; use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; +use test_programs::nn::{sort_results, witx}; pub fn main() -> Result<()> { let xml = fs::read("fixture/model.xml") .context("the model file to be mapped to the fixture directory")?; let weights = fs::read("fixture/model.bin") .context("the weights file to be mapped to the fixture directory")?; - let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) - .build_from_bytes([&xml, &weights])?; + let graph = witx::load( + &[&xml, &weights], + witx::GraphEncoding::Openvino, + witx::ExecutionTarget::CPU, + )?; let tensor = fs::read("fixture/tensor.bgr") .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; + let results = witx::classify(graph, tensor)?; let top_five = &sort_results(&results)[..5]; println!("found results, sorted top 5: {:?}", top_five); Ok(()) diff --git a/crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs b/crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs new file mode 100644 index 000000000000..d91c78a5c7b5 --- /dev/null +++ b/crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs @@ -0,0 +1,17 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, witx}; + +pub fn main() -> Result<()> { + let graph = witx::load_by_name( + "fixtures", + witx::GraphEncoding::Openvino, + witx::ExecutionTarget::CPU, + )?; + let tensor: Vec = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = witx::classify(graph, tensor)?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git 
a/crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs b/crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs new file mode 100644 index 000000000000..87808f32c2c3 --- /dev/null +++ b/crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs @@ -0,0 +1,18 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, witx}; + +pub fn main() -> Result<()> { + let graph = witx::load_by_name( + "mobilenet", + witx::GraphEncoding::Onnx, + witx::ExecutionTarget::CPU, + )?; + let tensor = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = witx::classify(graph, tensor)?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + assert_eq!(top_five[0].class_id(), 284); + Ok(()) +} diff --git a/crates/test-programs/src/nn.rs b/crates/test-programs/src/nn.rs index f3b54b460901..361a7c1f6282 100644 --- a/crates/test-programs/src/nn.rs +++ b/crates/test-programs/src/nn.rs @@ -1,39 +1,147 @@ -use anyhow::Result; -use std::time::Instant; -use wasi_nn::{Graph, TensorType}; - -/// Run a wasi-nn inference using a simple classifier model (single input, -/// single output). -pub fn classify(graph: Graph, tensor: Vec) -> Result> { - let mut context = graph.init_execution_context()?; - println!( - "[nn] created wasi-nn execution context with ID: {}", - context - ); - - // Many classifiers have a single input; currently, this test suite also - // uses tensors of the same shape, though this is not usually the case. - context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?; - println!("[nn] set input tensor: {} bytes", tensor.len()); - - let before = Instant::now(); - context.compute()?; - println!( - "[nn] executed graph inference in {} ms", - before.elapsed().as_millis() - ); - - // Many classifiers emit probabilities as floating point values; here we - // convert the raw bytes to `f32` knowing all models used here use that - // type. - let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::()]; - let num_bytes = context.get_output(0, &mut output_buffer)?; - println!("[nn] retrieved output tensor: {} bytes", num_bytes); - let output: Vec = output_buffer[..num_bytes] - .chunks(4) - .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) - .collect(); - Ok(output) +//! This module attempts to paper over the differences between the two +//! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and +//! the up-to-date WIT version (`mod wit`). Since the tests are mainly a simple +//! classifier, this exposes a high-level `classify` function to go along with +//! `load`, etc. +//! +//! This module exists solely for convenience--e.g., reduces test duplication. +//! In the future can be safely disposed of or altered as more tests are added. + +/// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the +/// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests. +pub mod wit { + use anyhow::{anyhow, Result}; + use std::time::Instant; + + // Generate the wasi-nn bindings based on the `*.wit` files. + wit_bindgen::generate!({ + path: "../wasi-nn/wit", + world: "ml", + default_bindings_module: "test_programs::ml" + }); + use self::wasi::nn::errors; + use self::wasi::nn::graph::{self, Graph}; + pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests. 
+    use self::wasi::nn::tensor::{Tensor, TensorType};
+
+    /// Load a wasi-nn graph from a set of bytes.
+    pub fn load(
+        bytes: &[Vec<u8>],
+        encoding: GraphEncoding,
+        target: ExecutionTarget,
+    ) -> Result<Graph> {
+        graph::load(bytes, encoding, target).map_err(err_as_anyhow)
+    }
+
+    /// Load a wasi-nn graph by name.
+    pub fn load_by_name(name: &str) -> Result<Graph> {
+        graph::load_by_name(name).map_err(err_as_anyhow)
+    }
+
+    /// Run a wasi-nn inference using a simple classifier model (single input,
+    /// single output).
+    pub fn classify(graph: Graph, input: (&str, Vec<u8>), output: &str) -> Result<Vec<f32>> {
+        let context = graph.init_execution_context().map_err(err_as_anyhow)?;
+        println!(
+            "[nn] created wasi-nn execution context with ID: {:?}",
+            context
+        );
+
+        // Many classifiers have a single input; currently, this test suite also
+        // uses tensors of the same shape, though this is not usually the case.
+        let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
+        context.set_input(input.0, tensor).map_err(err_as_anyhow)?;
+        println!("[nn] set input tensor: {} bytes", input.1.len());
+
+        let before = Instant::now();
+        context.compute().map_err(err_as_anyhow)?;
+        println!(
+            "[nn] executed graph inference in {} ms",
+            before.elapsed().as_millis()
+        );
+
+        // Many classifiers emit probabilities as floating point values; here we
+        // convert the raw bytes to `f32` knowing all models used here use that
+        // type.
+        let output = context.get_output(output).map_err(err_as_anyhow)?;
+        println!(
+            "[nn] retrieved output tensor: {} bytes",
+            output.data().len()
+        );
+        let output: Vec<f32> = output
+            .data()
+            .chunks(4)
+            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+            .collect();
+        Ok(output)
+    }
+
+    fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
+        anyhow!("error: {e:?}")
+    }
+}
+
+/// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
+/// tooling. This older API has been deprecated in favor of the newer WIT-based
+/// API but is retained for backwards-compatibility testing--i.e., the
+/// `bin/nn_witx_*.rs` tests.
+pub mod witx {
+    use anyhow::Result;
+    use std::time::Instant;
+    pub use wasi_nn::{ExecutionTarget, GraphEncoding};
+    use wasi_nn::{Graph, GraphBuilder, TensorType};
+
+    /// Load a wasi-nn graph from a set of bytes.
+    pub fn load(
+        bytes: &[&[u8]],
+        encoding: GraphEncoding,
+        target: ExecutionTarget,
+    ) -> Result<Graph> {
+        Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
+    }
+
+    /// Load a wasi-nn graph by name.
+    pub fn load_by_name(
+        name: &str,
+        encoding: GraphEncoding,
+        target: ExecutionTarget,
+    ) -> Result<Graph> {
+        Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
+    }
+
+    /// Run a wasi-nn inference using a simple classifier model (single input,
+    /// single output).
+    pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
+        let mut context = graph.init_execution_context()?;
+        println!(
+            "[nn] created wasi-nn execution context with ID: {}",
+            context
+        );
+
+        // Many classifiers have a single input; currently, this test suite also
+        // uses tensors of the same shape, though this is not usually the case.
+        context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
+        println!("[nn] set input tensor: {} bytes", tensor.len());
+
+        let before = Instant::now();
+        context.compute()?;
+        println!(
+            "[nn] executed graph inference in {} ms",
+            before.elapsed().as_millis()
+        );
+
+        // Many classifiers emit probabilities as floating point values; here we
+        // convert the raw bytes to `f32` knowing all models used here use that
+        // type.
+ let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::()]; + let num_bytes = context.get_output(0, &mut output_buffer)?; + println!("[nn] retrieved output tensor: {} bytes", num_bytes); + let output: Vec = output_buffer[..num_bytes] + .chunks(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(); + Ok(output) + } } /// Sort some classification probabilities. diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index 7390c1e33146..a5ace788d03a 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -20,7 +20,11 @@ anyhow = { workspace = true, features = ['std'] } wiggle = { workspace = true, features = ["wasmtime"] } # This dependency is necessary for the WIT-generation macros to work: -wasmtime = { workspace = true, features = ["component-model", "runtime"] } +wasmtime = { workspace = true, features = [ + "component-model", + "runtime", + "std", +] } # These dependencies are necessary for the wasi-nn implementation: tracing = { workspace = true } @@ -29,7 +33,7 @@ openvino = { version = "0.6.0", features = [ "runtime-linking", ], optional = true } -ort = { version = "2.0.0-rc.0", default-features = false, features = [ +ort = { version = "2.0.0-rc.2", default-features = false, features = [ "copy-dylibs", "download-binaries", ], optional = true } @@ -46,16 +50,17 @@ walkdir = { workspace = true } cap-std = { workspace = true } libtest-mimic = { workspace = true } test-programs-artifacts = { workspace = true } -wasi-common = { workspace = true, features = ["sync"] } +wasmtime-wasi = { workspace = true, features = ["preview1"] } wasmtime = { workspace = true, features = ["cranelift"] } +tracing-subscriber = { workspace = true } [features] default = ["openvino", "winml"] -# openvino is available on all platforms, it requires openvino installed. +# OpenVINO is available on all platforms; it requires OpenVINO to be installed. openvino = ["dep:openvino"] -# onnx is available on all platforms. +# ONNX is available on all platforms. onnx = ["dep:ort"] -# winml is only available on Windows 10 1809 and later. +# WinML is only available on Windows 10 1809 and later. winml = ["dep:windows"] [[test]] diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs index e06f9e364925..710c93b980f4 100644 --- a/crates/wasi-nn/src/backend/mod.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -3,20 +3,20 @@ //! implementations to maintain backend-specific state between calls. #[cfg(feature = "onnx")] -pub mod onnxruntime; +pub mod onnx; #[cfg(feature = "openvino")] pub mod openvino; #[cfg(all(feature = "winml", target_os = "windows"))] pub mod winml; #[cfg(feature = "onnx")] -use self::onnxruntime::OnnxBackend; +use self::onnx::OnnxBackend; #[cfg(feature = "openvino")] use self::openvino::OpenvinoBackend; #[cfg(all(feature = "winml", target_os = "windows"))] use self::winml::WinMLBackend; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; +use crate::wit::{ExecutionTarget, GraphEncoding, Tensor}; use crate::{Backend, ExecutionContext, Graph}; use std::fs::File; use std::io::Read; @@ -69,9 +69,30 @@ pub trait BackendGraph: Send + Sync { /// A [BackendExecutionContext] performs the actual inference; this is the /// backing implementation for a user-facing execution context. 
 pub trait BackendExecutionContext: Send + Sync {
-    fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>;
+    fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>;
     fn compute(&mut self) -> Result<(), BackendError>;
-    fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result<u32, BackendError>;
+    fn get_output(&mut self, id: Id) -> Result<Tensor, BackendError>;
+}
+
+/// An identifier for a tensor in a [Graph].
+#[derive(Debug)]
+pub enum Id {
+    Index(u32),
+    Name(String),
+}
+impl Id {
+    pub fn index(&self) -> Option<u32> {
+        match self {
+            Id::Index(i) => Some(*i),
+            Id::Name(_) => None,
+        }
+    }
+    pub fn name(&self) -> Option<&str> {
+        match self {
+            Id::Index(_) => None,
+            Id::Name(n) => Some(n),
+        }
+    }
 }
 
 /// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all
diff --git a/crates/wasi-nn/src/backend/onnx.rs b/crates/wasi-nn/src/backend/onnx.rs
new file mode 100644
index 000000000000..b6ae2ddf278e
--- /dev/null
+++ b/crates/wasi-nn/src/backend/onnx.rs
@@ -0,0 +1,338 @@
+//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate.
+
+use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner};
+use crate::backend::{read, Id};
+use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
+use crate::{ExecutionContext, Graph};
+use anyhow::Context;
+use ort::{inputs, GraphOptimizationLevel, Session};
+use std::path::Path;
+use std::sync::{Arc, Mutex};
+
+#[derive(Default)]
+pub struct OnnxBackend();
+unsafe impl Send for OnnxBackend {}
+unsafe impl Sync for OnnxBackend {}
+
+impl BackendInner for OnnxBackend {
+    fn encoding(&self) -> GraphEncoding {
+        GraphEncoding::Onnx
+    }
+
+    fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
+        if builders.len() != 1 {
+            return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
+        }
+
+        let session = Session::builder()?
+            .with_optimization_level(GraphOptimizationLevel::Level3)?
+            .commit_from_memory(builders[0])?;
+
+        let box_: Box<dyn BackendGraph> =
+            Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target));
+        Ok(box_.into())
+    }
+
+    fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> {
+        Some(self)
+    }
+}
+
+impl BackendFromDir for OnnxBackend {
+    fn load_from_dir(
+        &mut self,
+        path: &Path,
+        target: ExecutionTarget,
+    ) -> Result<Graph, BackendError> {
+        let model = read(&path.join("model.onnx"))?;
+        self.load(&[&model], target)
+    }
+}
+
+struct OnnxGraph(Arc<Mutex<Session>>, #[allow(dead_code)] ExecutionTarget);
+unsafe impl Send for OnnxGraph {}
+unsafe impl Sync for OnnxGraph {}
+
+impl BackendGraph for OnnxGraph {
+    fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
+        let session = self.0.lock().unwrap();
+        // We need to hold on to the names of the inputs in order for
+        // `set_input` to work with both indexes and names. Having the
+        // dimensions and type around is useful for validation but could be
+        // retrieved from the session.
+        let mut inputs = vec![];
+        for input in &session.inputs {
+            let shape = Shape::from_onnx_input(input)?;
+            inputs.push(TensorSlot {
+                shape,
+                tensor: None,
+            });
+        }
+        // We need to keep track of the output shapes since they are used for
+        // creating the output tensor.
+ let mut outputs = vec![]; + for output in &session.outputs { + let shape = Shape::from_onnx_output(output)?; + outputs.push(TensorSlot { + shape, + tensor: None, + }); + } + let box_: Box = Box::new(OnnxExecutionContext { + session: self.0.clone(), + inputs, + outputs, + }); + Ok(box_.into()) + } +} + +struct OnnxExecutionContext { + session: Arc>, + inputs: Vec, + outputs: Vec, +} + +unsafe impl Send for OnnxExecutionContext {} +unsafe impl Sync for OnnxExecutionContext {} + +impl OnnxExecutionContext { + /// Helper function for finding the internal index of a tensor by [`Id`]. + fn find(&self, id: Id, list: &[TensorSlot]) -> Result { + let index = match id { + Id::Index(i) => { + let i = i as usize; + if i < list.len() { + i + } else { + return Err(BackendError::BackendAccess(anyhow::anyhow!( + "incorrect tensor index: {i} >= {}", + list.len() + ))); + } + } + Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| { + BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}")) + })?, + }; + Ok(index) + } +} + +impl BackendExecutionContext for OnnxExecutionContext { + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { + let index = self.find(id, &self.inputs)?; + let input = &mut self.inputs[index]; + if let Err(e) = input.shape.matches(tensor) { + return Err(e.into()); + } + // Hold the tensor data on the context until `compute` is called. + input.tensor.replace(tensor.clone()); + Ok(()) + } + + fn compute(&mut self) -> Result<(), BackendError> { + let mut session_inputs: Vec> = vec![]; + for i in &self.inputs { + session_inputs.extend(to_input_value(i)?); + } + let session = self.session.lock().unwrap(); + let session_outputs = session.run(session_inputs.as_slice())?; + for i in 0..self.outputs.len() { + // TODO: fix preexisting gap--this only handles f32 tensors. + let raw: (Vec, &[f32]) = session_outputs[i].try_extract_raw_tensor()?; + let f32s = raw.1.to_vec(); + let output = &mut self.outputs[i]; + output.tensor.replace(Tensor { + dimensions: output.shape.dimensions_as_u32()?, + ty: output.shape.ty, + data: f32_vec_to_bytes(f32s), + }); + } + Ok(()) + } + + fn get_output(&mut self, id: Id) -> Result { + let index = self.find(id, &self.outputs)?; + let output = &self.outputs[index]; + if let Some(tensor) = &output.tensor { + Ok(tensor.clone()) + } else { + Err(BackendError::BackendAccess(anyhow::anyhow!( + "missing output tensor: {}; has `compute` been called?", + output.shape.name + ))) + } + } +} + +impl From for BackendError { + fn from(e: ort::Error) -> Self { + BackendError::BackendAccess(e.into()) + } +} + +/// Holds a slot for ONNX session inputs and outputs. +/// +/// TODO: it seems unfortunate that we have to "hold" some extra data per +/// session but in the input case, this is necessary for name-based indexing. +struct TensorSlot { + shape: Shape, + tensor: Option, +} + +/// Describes a tensor in ONNX terms. 
+struct Shape { + name: String, + dimensions: Vec, + ty: TensorType, +} + +impl Shape { + fn from_onnx_input(input: &ort::Input) -> Result { + let name = input.name.clone(); + let (dimensions, ty) = convert_value_type(&input.input_type)?; + Ok(Self { + name, + dimensions, + ty, + }) + } + + fn from_onnx_output(output: &ort::Output) -> Result { + let name = output.name.clone(); + let (dimensions, ty) = convert_value_type(&output.output_type)?; + Ok(Self { + name, + dimensions, + ty, + }) + } + + fn dimensions_as_u32(&self) -> Result, BackendError> { + self.dimensions + .iter() + .map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) }) + .collect() + } + + fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> { + if self.dimensions.len() != tensor.dimensions.len() { + return Err(anyhow::anyhow!( + "input tensor cardinality does not match model: {:?} != {:?}", + self.dimensions, + tensor.dimensions + )); + } else { + for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) { + let tensor_dim = tensor_dim as i64; + if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim { + return Err(anyhow::anyhow!( + "input tensor dimensions do not match model: {:?} != {:?}", + self.dimensions, + tensor.dimensions + )); + } + } + } + if self.ty != tensor.ty { + return Err(anyhow::anyhow!( + "input tensor type does not match model: {:?} != {:?}", + self.ty, + tensor.ty + )); + } + Ok(()) + } +} + +fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec, TensorType), BackendError> { + match vt { + ort::ValueType::Tensor { ty, dimensions } => { + let dims = dimensions.clone(); + let ty = (*ty).try_into()?; + Ok((dims, ty)) + } + _ => Err(BackendError::BackendAccess(anyhow::anyhow!( + "unsupported input type: {vt:?}" + ))), + } +} + +fn convert_i64(i: &i64) -> Result { + u32::try_from(*i).map_err(|d| -> BackendError { + anyhow::anyhow!("unable to convert dimension to u32: {d}").into() + }) +} + +impl TryFrom for TensorType { + type Error = BackendError; + fn try_from(ty: ort::TensorElementType) -> Result { + match ty { + ort::TensorElementType::Float32 => Ok(TensorType::Fp32), + ort::TensorElementType::Float64 => Ok(TensorType::Fp64), + ort::TensorElementType::Uint8 => Ok(TensorType::U8), + ort::TensorElementType::Int32 => Ok(TensorType::I32), + ort::TensorElementType::Int64 => Ok(TensorType::I64), + _ => Err(BackendError::BackendAccess(anyhow::anyhow!( + "unsupported tensor type: {ty:?}" + ))), + } + } +} + +fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> { + match &slot.tensor { + Some(tensor) => match tensor.ty { + TensorType::Fp32 => { + let data = bytes_to_f32_vec(tensor.data.to_vec()); + let dimensions = tensor + .dimensions + .iter() + .map(|d| *d as i64) // TODO: fewer conversions + .collect::>(); + Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))] + .context("failed to create ONNX session input")?) 
+ } + _ => { + unimplemented!("{:?} not supported by ONNX", tensor.ty); + } + }, + None => { + return Err(BackendError::BackendAccess(anyhow::anyhow!( + "missing input tensor: {}", + slot.shape.name + ))); + } + } +} + +pub fn f32_vec_to_bytes(data: Vec) -> Vec { + let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect(); + let result: Vec = chunks.iter().flatten().copied().collect(); + result +} + +pub fn bytes_to_f32_vec(data: Vec) -> Vec { + let chunks: Vec<&[u8]> = data.chunks(4).collect(); + let v: Vec = chunks + .into_iter() + .map(|c| f32::from_le_bytes(c.try_into().unwrap())) + .collect(); + + v.into_iter().collect() +} + +/// Returns whether the dimension is dynamic. +/// +/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the +/// value of a tensor dimension is user-defined, not fixed by the model. This is +/// useful for batching up several inference requests, e.g. When `ort` returns a +/// dimension of this kind, though, it uses `-1` to indicate that the dimension +/// is dynamic. +/// +/// [dimensional variables]: +/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes +fn is_dynamic_dimension(d: i64) -> bool { + d == -1 +} diff --git a/crates/wasi-nn/src/backend/onnxruntime.rs b/crates/wasi-nn/src/backend/onnxruntime.rs deleted file mode 100644 index bddb03dc9b35..000000000000 --- a/crates/wasi-nn/src/backend/onnxruntime.rs +++ /dev/null @@ -1,149 +0,0 @@ -//! Implements a `wasi-nn` [`BackendInner`] using ONNX via ort. - -use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; -use crate::backend::read; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; -use crate::{ExecutionContext, Graph}; -use ort::{inputs, GraphOptimizationLevel, Session}; -use std::path::Path; -use std::sync::{Arc, Mutex}; - -#[derive(Default)] -pub struct OnnxBackend(); -unsafe impl Send for OnnxBackend {} -unsafe impl Sync for OnnxBackend {} - -impl BackendInner for OnnxBackend { - fn encoding(&self) -> GraphEncoding { - GraphEncoding::Onnx - } - - fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result { - if builders.len() != 1 { - return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into()); - } - - let session = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? 
- .with_model_from_memory(builders[0])?; - - let box_: Box = - Box::new(ONNXGraph(Arc::new(Mutex::new(session)), target)); - Ok(box_.into()) - } - - fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> { - Some(self) - } -} - -impl BackendFromDir for OnnxBackend { - fn load_from_dir( - &mut self, - path: &Path, - target: ExecutionTarget, - ) -> Result { - let model = read(&path.join("model.onnx"))?; - self.load(&[&model], target) - } -} - -struct ONNXGraph(Arc>, #[allow(dead_code)] ExecutionTarget); - -unsafe impl Send for ONNXGraph {} -unsafe impl Sync for ONNXGraph {} - -impl BackendGraph for ONNXGraph { - fn init_execution_context(&self) -> Result { - let session = self.0.lock().unwrap(); - let inputs = session.inputs.iter().map(|_| None).collect::>(); - let outputs = session.outputs.iter().map(|_| None).collect::>(); - let box_: Box = Box::new(ONNXExecutionContext { - session: self.0.clone(), - inputs, - outputs, - }); - Ok(box_.into()) - } -} - -struct ONNXExecutionContext { - session: Arc>, - inputs: Vec>, - outputs: Vec>>, -} - -unsafe impl Send for ONNXExecutionContext {} -unsafe impl Sync for ONNXExecutionContext {} - -impl BackendExecutionContext for ONNXExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { - self.inputs[index as usize].replace(tensor.clone()); - Ok(()) - } - - fn compute(&mut self) -> Result<(), BackendError> { - let shaped_inputs: Vec<_> = self - .inputs - .iter() - .enumerate() - .map(|(i, _o)| { - let input = self.inputs[i].as_ref().unwrap(); - let dims = input - .dimensions - .as_slice() - .iter() - .map(|d| *d as i64) - .collect::>(); - match input.tensor_type { - TensorType::Fp32 => { - let data = bytes_to_f32_vec(input.data.to_vec()); - inputs![(dims, Arc::new(data.into_boxed_slice()))].unwrap() - } - _ => { - unimplemented!("{:?} not supported by ONNX", input.tensor_type); - } - } - }) - .flatten() - .collect(); - - let session = self.session.lock().unwrap(); - let res = session.run(shaped_inputs.as_slice())?; - - for i in 0..self.outputs.len() { - let raw: (Vec, &[f32]) = res[i].extract_raw_tensor()?; - let f32s = raw.1.to_vec(); - self.outputs[i].replace(f32_vec_to_bytes(f32s)); - } - Ok(()) - } - - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result { - let output = self.outputs[index as usize].as_ref().unwrap(); - destination[..output.len()].copy_from_slice(output); - Ok(output.len() as u32) - } -} - -impl From for BackendError { - fn from(e: ort::Error) -> Self { - BackendError::BackendAccess(e.into()) - } -} - -pub fn f32_vec_to_bytes(data: Vec) -> Vec { - let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect(); - let result: Vec = chunks.iter().flatten().copied().collect(); - result -} - -pub fn bytes_to_f32_vec(data: Vec) -> Vec { - let chunks: Vec<&[u8]> = data.chunks(4).collect(); - let v: Vec = chunks - .into_iter() - .map(|c| f32::from_le_bytes(c.try_into().unwrap())) - .collect(); - - v.into_iter().collect() -} diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs index 65c96629eead..b24f93838cab 100644 --- a/crates/wasi-nn/src/backend/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -1,9 +1,9 @@ //! Implements a `wasi-nn` [`BackendInner`] using OpenVINO. 
use super::{ - read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, + read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id, }; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; +use crate::wit::{self, ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; use std::path::Path; @@ -99,12 +99,15 @@ impl BackendGraph for OpenvinoGraph { struct OpenvinoExecutionContext(Arc, openvino::InferRequest); impl BackendExecutionContext for OpenvinoExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { - let input_name = self.0.get_input_name(index as usize)?; + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { + let input_name = match id { + Id::Index(i) => self.0.get_input_name(i as usize)?, + Id::Name(name) => name, + }; // Construct the blob structure. TODO: there must be some good way to // discover the layout here; `desc` should not have to default to NHWC. - let precision = map_tensor_type_to_precision(tensor.tensor_type); + let precision = map_tensor_type_to_precision(tensor.ty); let dimensions = tensor .dimensions .iter() @@ -123,17 +126,20 @@ impl BackendExecutionContext for OpenvinoExecutionContext { Ok(()) } - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result { - let output_name = self.0.get_output_name(index as usize)?; + fn get_output(&mut self, id: Id) -> Result { + let output_name = match id { + Id::Index(i) => self.0.get_output_name(i as usize)?, + Id::Name(name) => name, + }; + let dimensions = vec![]; // TODO: get actual shape + let ty = wit::TensorType::Fp32; // TODO: get actual type. let blob = self.1.get_blob(&output_name)?; - let blob_size = blob.byte_len()?; - if blob_size > destination.len() { - return Err(BackendError::NotEnoughMemory(blob_size)); - } - - // Copy the tensor data into the destination buffer. - destination[..blob_size].copy_from_slice(blob.buffer()?); - Ok(blob_size as u32) + let data = blob.buffer()?.to_vec(); + Ok(Tensor { + dimensions, + ty, + data, + }) } } diff --git a/crates/wasi-nn/src/backend/winml.rs b/crates/wasi-nn/src/backend/winml.rs index e11761f86732..b87510bfc877 100644 --- a/crates/wasi-nn/src/backend/winml.rs +++ b/crates/wasi-nn/src/backend/winml.rs @@ -1,16 +1,27 @@ //! Implements a `wasi-nn` [`BackendInner`] using WinML. - -use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; +//! +//! Note that the [docs.rs] documentation for the `windows` crate does have the +//! right features turned on to read about the functions used; see Microsoft's +//! private documentation instead: [microsoft.github.io/windows-docs-rs]. +//! +//! [docs.rs]: https://docs.rs/windows +//! 
[microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning + +use crate::backend::{ + BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id, +}; +use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; use std::{fs::File, io::Read, mem::size_of, path::Path}; use windows::core::{ComInterface, HSTRING}; +use windows::Foundation::Collections::IVectorView; use windows::Storage::Streams::{ DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference, }; use windows::AI::MachineLearning::{ - LearningModel, LearningModelBinding, LearningModelDevice, LearningModelDeviceKind, - LearningModelEvaluationResult, LearningModelSession, TensorFeatureDescriptor, TensorFloat, + ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice, + LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession, + TensorFeatureDescriptor, TensorFloat, }; #[derive(Default)] @@ -94,29 +105,64 @@ impl WinMLExecutionContext { } } +impl WinMLExecutionContext { + /// Helper function for finding the internal index of a tensor by [`Id`]. + fn find( + &self, + id: Id, + list: &IVectorView, + ) -> Result { + let index = match id { + Id::Index(i) => { + if i < list.Size()? { + i + } else { + return Err(BackendError::BackendAccess(anyhow::anyhow!( + "incorrect tensor index: {i} >= {}", + list.Size()? + ))); + } + } + Id::Name(name) => list + .into_iter() + .position(|d| d.Name().unwrap() == name) + .ok_or_else(|| { + BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {name}")) + })? as u32, + }; + Ok(index) + } +} + impl BackendExecutionContext for WinMLExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { + let input_features = self.session.Model()?.InputFeatures()?; + let index = self.find(id, &input_features)?; + let input = input_features.GetAt(index)?; + // TODO: Support other tensor types. Only FP32 is supported right now. - match tensor.tensor_type { + match tensor.ty { crate::wit::types::TensorType::Fp32 => {} _ => unimplemented!(), } - let input = self.session.Model()?.InputFeatures()?.GetAt(index)?; - unsafe { - let data = std::slice::from_raw_parts( + // TODO: this is quite unsafe and probably incorrect--will the slice + // still be around by the time the binding is used?! + let data = unsafe { + std::slice::from_raw_parts( tensor.data.as_ptr() as *const f32, - tensor.data.len() / 4, - ); - - self.binding.Bind( - &input.Name()?, - &TensorFloat::CreateFromArray( - &input.cast::()?.Shape()?, - data, - )?, - )?; - } + tensor.data.len() / size_of::(), + ) + }; + + self.binding.Bind( + &input.Name()?, + &TensorFloat::CreateFromArray( + &input.cast::()?.Shape()?, + data, + )?, + )?; + Ok(()) } @@ -125,33 +171,32 @@ impl BackendExecutionContext for WinMLExecutionContext { Ok(()) } - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result { - if self.result.is_none() { + fn get_output(&mut self, id: Id) -> Result { + if let Some(result) = &self.result { + let output_features = self.session.Model()?.OutputFeatures()?; + let index = self.find(id, &output_features)?; + let output = output_features.GetAt(index)?; + // TODO: this only handles FP32! + let tensor = result + .Outputs()? + .Lookup(&output.Name()?)? 
+ .cast::()?; + let dimensions = dimensions_as_u32(&tensor.Shape()?)?; + let view = tensor.GetAsVectorView()?; + let mut data = Vec::with_capacity(view.Size()? as usize * size_of::()); + for f in view.into_iter() { + data.extend(f.to_le_bytes()); + } + Ok(Tensor { + ty: TensorType::Fp32, + dimensions, + data, + }) + } else { return Err(BackendError::BackendAccess(anyhow::Error::msg( "Output is not ready.", ))); } - let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?; - let output_name_hstring = output_name.Name()?; - - let vector_view = self - .result - .as_ref() - .unwrap() - .Outputs()? - .Lookup(&output_name_hstring)? - .cast::()? - .GetAsVectorView()?; - let output: Vec = vector_view.into_iter().collect(); - let len_to_copy = output.len() * size_of::(); - unsafe { - destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts( - output.as_ptr() as *const u8, - len_to_copy, - )); - } - - Ok(len_to_copy as u32) } } @@ -168,3 +213,16 @@ impl From for BackendError { BackendError::BackendAccess(anyhow::Error::new(e)) } } + +fn dimensions_as_u32(dimensions: &IVectorView) -> Result, BackendError> { + dimensions + .into_iter() + .map(|d| if d == -1 { Ok(1) } else { convert_i64(d) }) + .collect() +} + +fn convert_i64(i: i64) -> Result { + u32::try_from(i).map_err(|d| -> BackendError { + anyhow::anyhow!("unable to convert dimension to u32: {d}").into() + }) +} diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs deleted file mode 100644 index 40cfb9d53d2d..000000000000 --- a/crates/wasi-nn/src/ctx.rs +++ /dev/null @@ -1,146 +0,0 @@ -//! Implements the host state for the `wasi-nn` API: [WasiNnCtx]. - -use crate::backend::{self, BackendError}; -use crate::wit::types::GraphEncoding; -use crate::{Backend, ExecutionContext, Graph, InMemoryRegistry, Registry}; -use anyhow::anyhow; -use std::{collections::HashMap, hash::Hash, path::Path}; -use thiserror::Error; -use wiggle::GuestError; - -type GraphId = u32; -type GraphExecutionContextId = u32; -type BackendName = String; -type GraphDirectory = String; - -/// Construct an in-memory registry from the available backends and a list of -/// `(, )`. This assumes graphs can be loaded -/// from a local directory, which is a safe assumption currently for the current -/// model types. -pub fn preload( - preload_graphs: &[(BackendName, GraphDirectory)], -) -> anyhow::Result<(impl IntoIterator, Registry)> { - let mut backends = backend::list(); - let mut registry = InMemoryRegistry::new(); - for (kind, path) in preload_graphs { - let kind_ = kind.parse()?; - let backend = backends - .iter_mut() - .find(|b| b.encoding() == kind_) - .ok_or(anyhow!("unsupported backend: {}", kind))? - .as_dir_loadable() - .ok_or(anyhow!("{} does not support directory loading", kind))?; - registry.load(backend, Path::new(path))?; - } - Ok((backends, Registry::from(registry))) -} - -/// Capture the state necessary for calling into the backend ML libraries. -pub struct WasiNnCtx { - pub(crate) backends: HashMap, - pub(crate) registry: Registry, - pub(crate) graphs: Table, - pub(crate) executions: Table, -} - -impl WasiNnCtx { - /// Make a new context from the default state. - pub fn new(backends: impl IntoIterator, registry: Registry) -> Self { - let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect(); - Self { - backends, - registry, - graphs: Table::default(), - executions: Table::default(), - } - } -} - -/// Possible errors while interacting with [WasiNnCtx]. 
-#[derive(Debug, Error)] -pub enum WasiNnError { - #[error("backend error")] - BackendError(#[from] BackendError), - #[error("guest error")] - GuestError(#[from] GuestError), - #[error("usage error")] - UsageError(#[from] UsageError), -} - -#[derive(Debug, Error)] -pub enum UsageError { - #[error("Invalid context; has the load function been called?")] - InvalidContext, - #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] - InvalidEncoding(GraphEncoding), - #[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")] - InvalidNumberOfBuilders(u32), - #[error("Invalid graph handle; has it been loaded?")] - InvalidGraphHandle, - #[error("Invalid execution context handle; has it been initialized?")] - InvalidExecutionContextHandle, - #[error("Not enough memory to copy tensor data of size: {0}")] - NotEnoughMemory(u32), - #[error("No graph found with name: {0}")] - NotFound(String), -} - -pub(crate) type WasiNnResult = std::result::Result; - -/// Record handle entries in a table. -pub struct Table { - entries: HashMap, - next_key: u32, -} - -impl Default for Table { - fn default() -> Self { - Self { - entries: HashMap::new(), - next_key: 0, - } - } -} - -impl Table -where - K: Eq + Hash + From + Copy, -{ - pub fn insert(&mut self, value: V) -> K { - let key = self.use_next_key(); - self.entries.insert(key, value); - key - } - - pub fn get(&self, key: K) -> Option<&V> { - self.entries.get(&key) - } - - pub fn get_mut(&mut self, key: K) -> Option<&mut V> { - self.entries.get_mut(&key) - } - - fn use_next_key(&mut self) -> K { - let current = self.next_key; - self.next_key += 1; - K::from(current) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::registry::GraphRegistry; - - #[test] - fn example() { - struct FakeRegistry; - impl GraphRegistry for FakeRegistry { - fn get_mut(&mut self, _: &str) -> Option<&mut Graph> { - None - } - } - - let _ctx = WasiNnCtx::new([], Registry::from(FakeRegistry)); - } -} diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 71d089d07489..a9f86f2da5a2 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -1,14 +1,34 @@ -mod ctx; -mod registry; - pub mod backend; -pub use ctx::{preload, WasiNnCtx}; -pub use registry::{GraphRegistry, InMemoryRegistry}; +mod registry; pub mod wit; pub mod witx; +use anyhow::anyhow; +use core::fmt; +pub use registry::{GraphRegistry, InMemoryRegistry}; +use std::path::Path; use std::sync::Arc; +/// Construct an in-memory registry from the available backends and a list of +/// `(, )`. This assumes graphs can be loaded +/// from a local directory, which is a safe assumption currently for the current +/// model types. +pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec, Registry)> { + let mut backends = backend::list(); + let mut registry = InMemoryRegistry::new(); + for (kind, path) in preload_graphs { + let kind_ = kind.parse()?; + let backend = backends + .iter_mut() + .find(|b| b.encoding() == kind_) + .ok_or(anyhow!("unsupported backend: {}", kind))? + .as_dir_loadable() + .ok_or(anyhow!("{} does not support directory loading", kind))?; + registry.load(backend, Path::new(path))?; + } + Ok((backends, Registry::from(registry))) +} + /// A machine learning backend. pub struct Backend(Box); impl std::ops::Deref for Backend { @@ -43,6 +63,27 @@ impl std::ops::Deref for Graph { } } +/// A host-side tensor. 
+///
+/// Eventually, this may be defined in each backend as they gain the ability to
+/// hold tensors on various devices (TODO:
+/// https://github.com/WebAssembly/wasi-nn/pull/70).
+#[derive(Clone)]
+pub struct Tensor {
+    dimensions: Vec<u32>,
+    ty: wit::TensorType,
+    data: Vec<u8>,
+}
+impl fmt::Debug for Tensor {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Tensor")
+            .field("dimensions", &self.dimensions)
+            .field("ty", &self.ty)
+            .field("data (bytes)", &self.data.len())
+            .finish()
+    }
+}
+
 /// A backend-defined execution context.
 pub struct ExecutionContext(Box<dyn BackendExecutionContext>);
 impl From<Box<dyn BackendExecutionContext>> for ExecutionContext {
diff --git a/crates/wasi-nn/src/registry/in_memory.rs b/crates/wasi-nn/src/registry/in_memory.rs
index b008f7f43684..86f81203d646 100644
--- a/crates/wasi-nn/src/registry/in_memory.rs
+++ b/crates/wasi-nn/src/registry/in_memory.rs
@@ -2,7 +2,7 @@
 
 use super::{Graph, GraphRegistry};
 use crate::backend::BackendFromDir;
-use crate::wit::types::ExecutionTarget;
+use crate::wit::ExecutionTarget;
 use anyhow::{anyhow, bail};
 use std::{collections::HashMap, path::Path};
 
@@ -37,6 +37,9 @@ impl InMemoryRegistry {
 }
 
 impl GraphRegistry for InMemoryRegistry {
+    fn get(&self, name: &str) -> Option<&Graph> {
+        self.0.get(name)
+    }
     fn get_mut(&mut self, name: &str) -> Option<&mut Graph> {
         self.0.get_mut(name)
     }
diff --git a/crates/wasi-nn/src/registry/mod.rs b/crates/wasi-nn/src/registry/mod.rs
index 83f88e4dca0e..5f4d959132dc 100644
--- a/crates/wasi-nn/src/registry/mod.rs
+++ b/crates/wasi-nn/src/registry/mod.rs
@@ -12,5 +12,6 @@ use crate::Graph;
 pub use in_memory::InMemoryRegistry;
 
 pub trait GraphRegistry: Send + Sync {
+    fn get(&self, name: &str) -> Option<&Graph>;
     fn get_mut(&mut self, name: &str) -> Option<&mut Graph>;
 }
diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs
index dbe894357cdf..40f6fc4c1ff6 100644
--- a/crates/wasi-nn/src/wit.rs
+++ b/crates/wasi-nn/src/wit.rs
@@ -15,8 +15,69 @@
 //! [`Backend`]: crate::Backend
 //! [`types`]: crate::wit::types
 
-use crate::{ctx::UsageError, WasiNnCtx};
-use std::{error::Error, fmt, hash::Hash, str::FromStr};
+use crate::backend::Id;
+use crate::{Backend, Registry};
+use std::collections::HashMap;
+use std::hash::Hash;
+use std::{fmt, str::FromStr};
+use wasmtime::component::{Resource, ResourceTable};
+
+/// Capture the state necessary for calling into the backend ML libraries.
+pub struct WasiNnCtx {
+    pub(crate) backends: HashMap<GraphEncoding, Backend>,
+    pub(crate) registry: Registry,
+}
+
+impl WasiNnCtx {
+    /// Make a new context from the default state.
+    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
+        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
+        Self { backends, registry }
+    }
+}
+
+/// A wrapper capturing the needed internal wasi-nn state.
+///
+/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`),
+/// this wrapper is not a `trait` but rather holds the references directly. This
+/// removes one layer of abstraction for simplicity only, and it could be added
+/// back in the future if embedders need more control here.
+pub struct WasiNnView<'a> {
+    ctx: &'a mut WasiNnCtx,
+    table: &'a mut ResourceTable,
+}
+
+impl<'a> WasiNnView<'a> {
+    /// Create a new view into the wasi-nn state.
+    pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self {
+        Self { ctx, table }
+    }
+}
+
+pub enum Error {
+    /// Caller module passed an invalid argument.
+    InvalidArgument,
+    /// Invalid encoding.
+    InvalidEncoding,
+    /// The operation timed out.
+ Timeout, + /// Runtime Error. + RuntimeError, + /// Unsupported operation. + UnsupportedOperation, + /// Graph is too large. + TooLarge, + /// Graph not found. + NotFound, + /// A runtime error occurred that we should trap on; see `StreamError`. + Trap(anyhow::Error), +} + +impl From for Error { + fn from(error: wasmtime::component::ResourceTableError) -> Self { + Self::Trap(error.into()) + } +} /// Generate the traits and types from the `wasi-nn` WIT specification. mod gen_ { @@ -24,126 +85,241 @@ mod gen_ { world: "ml", path: "wit/wasi-nn.wit", trappable_imports: true, + with: { + // Configure all WIT http resources to be defined types in this + // crate to use the `ResourceTable` helper methods. + "wasi:nn/graph/graph": crate::Graph, + "wasi:nn/tensor/tensor": crate::Tensor, + "wasi:nn/inference/graph-execution-context": crate::ExecutionContext, + }, + trappable_error_type: { + "wasi:nn/errors/error" => super::Error, + }, }); } -use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need. +use gen_::wasi::nn::{self as gen}; // Shortcut to the module containing the types we need. // Export the `types` used in this crate as well as `ML::add_to_linker`. pub mod types { use super::gen; - pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding}; + pub use gen::errors::Error; + pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding}; pub use gen::inference::GraphExecutionContext; pub use gen::tensor::{Tensor, TensorType}; } +pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding}; +pub use gen::inference::GraphExecutionContext; +pub use gen::tensor::{Tensor, TensorData, TensorDimensions, TensorType}; pub use gen_::Ml as ML; -impl gen::graph::Host for WasiNnCtx { - /// Load an opaque sequence of bytes to use for inference. +/// Add the WIT-based version of the `wasi-nn` API to a +/// [`wasmtime::component::Linker`]. +pub fn add_to_linker( + l: &mut wasmtime::component::Linker, + f: impl Fn(&mut T) -> WasiNnView<'_> + Send + Sync + Copy + 'static, +) -> anyhow::Result<()> { + gen::graph::add_to_linker_get_host(l, f)?; + gen::tensor::add_to_linker_get_host(l, f)?; + gen::inference::add_to_linker_get_host(l, f)?; + gen::errors::add_to_linker_get_host(l, f)?; + Ok(()) +} + +impl gen::graph::Host for WasiNnView<'_> { fn load( &mut self, - builders: Vec, - encoding: gen::graph::GraphEncoding, - target: gen::graph::ExecutionTarget, - ) -> wasmtime::Result> { - let graph = if let Some(backend) = self.backends.get_mut(&encoding) { + builders: Vec, + encoding: GraphEncoding, + target: ExecutionTarget, + ) -> Result, Error> { + tracing::debug!("load {encoding:?} {target:?}"); + if let Some(backend) = self.ctx.backends.get_mut(&encoding) { let slices = builders.iter().map(|s| s.as_slice()).collect::>(); - backend.load(&slices, target.into())? 
+            match backend.load(&slices, target.into()) {
+                Ok(graph) => {
+                    let graph = self.table.push(graph)?;
+                    Ok(graph)
+                }
+                Err(error) => {
+                    tracing::error!("failed to load graph: {error:?}");
+                    Err(Error::RuntimeError)
+                }
+            }
         } else {
-            return Err(UsageError::InvalidEncoding(encoding.into()).into());
-        };
-        let graph_id = self.graphs.insert(graph);
-        Ok(Ok(graph_id))
+            Err(Error::InvalidEncoding)
+        }
     }
-    fn load_by_name(
-        &mut self,
-        name: String,
-    ) -> wasmtime::Result<Result<gen::graph::Graph, gen::errors::Error>> {
-        if let Some(graph) = self.registry.get_mut(&name) {
-            let graph_id = self.graphs.insert(graph.clone().into());
-            Ok(Ok(graph_id))
+    fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
+        use core::result::Result::*;
+        tracing::debug!("load by name {name:?}");
+        let registry = &self.ctx.registry;
+        if let Some(graph) = registry.get(&name) {
+            let graph = graph.clone();
+            let graph = self.table.push(graph)?;
+            Ok(graph)
         } else {
-            return Err(UsageError::NotFound(name.to_string()).into());
+            tracing::error!("failed to find graph with name: {name}");
+            Err(Error::NotFound)
         }
     }
 }
-impl gen::inference::Host for WasiNnCtx {
-    /// Create an execution instance of a loaded graph.
-    ///
-    /// TODO: remove completely?
+impl gen::graph::HostGraph for WasiNnView<'_> {
     fn init_execution_context(
         &mut self,
-        graph_id: gen::graph::Graph,
-    ) -> wasmtime::Result<Result<gen::inference::GraphExecutionContext, gen::errors::Error>> {
-        let exec_context = if let Some(graph) = self.graphs.get(graph_id) {
-            graph.init_execution_context()?
-        } else {
-            return Err(UsageError::InvalidGraphHandle.into());
-        };
+        graph: Resource<Graph>,
+    ) -> Result<Resource<GraphExecutionContext>, Error> {
+        use core::result::Result::*;
+        tracing::debug!("initialize execution context");
+        let graph = self.table.get(&graph)?;
+        match graph.init_execution_context() {
+            Ok(exec_context) => {
+                let exec_context = self.table.push(exec_context)?;
+                Ok(exec_context)
+            }
+            Err(error) => {
+                tracing::error!("failed to initialize execution context: {error:?}");
+                Err(Error::RuntimeError)
+            }
+        }
+    }
-        let exec_context_id = self.executions.insert(exec_context);
-        Ok(Ok(exec_context_id))
+    fn drop(&mut self, graph: Resource<Graph>) -> wasmtime::Result<()> {
+        self.table.delete(graph)?;
+        Ok(())
     }
+}
-    /// Define the inputs to use for inference.
+impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
     fn set_input(
         &mut self,
-        exec_context_id: gen::inference::GraphExecutionContext,
-        index: u32,
-        tensor: gen::tensor::Tensor,
-    ) -> wasmtime::Result<Result<(), gen::errors::Error>> {
-        if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
-            exec_context.set_input(index, &tensor)?;
-            Ok(Ok(()))
+        exec_context: Resource<GraphExecutionContext>,
+        name: String,
+        tensor: Resource<Tensor>,
+    ) -> Result<(), Error> {
+        let tensor = self.table.get(&tensor)?;
+        tracing::debug!("set input {name:?}: {tensor:?}");
+        let tensor = tensor.clone(); // TODO: avoid copying the tensor
+        let exec_context = self.table.get_mut(&exec_context)?;
+        if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
+            tracing::error!("failed to set input: {e:?}");
+            Err(Error::InvalidArgument)
         } else {
-            Err(UsageError::InvalidGraphHandle.into())
+            Ok(())
         }
     }
-    /// Compute the inference on the given inputs.
-    ///
-    /// TODO: refactor to compute(list<tensor>) -> result<list<tensor>, error>
-    fn compute(
-        &mut self,
-        exec_context_id: gen::inference::GraphExecutionContext,
-    ) -> wasmtime::Result<Result<(), gen::errors::Error>> {
-        if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
-            exec_context.compute()?;
-            Ok(Ok(()))
-        } else {
-            Err(UsageError::InvalidExecutionContextHandle.into())
+    fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
+        let exec_context = &mut self.table.get_mut(&exec_context)?;
+        tracing::debug!("compute");
+        match exec_context.compute() {
+            Ok(()) => Ok(()),
+            Err(error) => {
+                tracing::error!("failed to compute: {error:?}");
+                Err(Error::RuntimeError)
+            }
         }
     }
-    /// Extract the outputs after inference.
+    #[doc = r" Extract the outputs after inference."]
     fn get_output(
         &mut self,
-        exec_context_id: gen::inference::GraphExecutionContext,
-        index: u32,
-    ) -> wasmtime::Result<Result<gen::tensor::TensorData, gen::errors::Error>> {
-        if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
-            // Read the output bytes. TODO: this involves a hard-coded upper
-            // limit on the tensor size that is necessary because there is no
-            // way to introspect the graph outputs
-            // (https://github.com/WebAssembly/wasi-nn/issues/37).
-            let mut destination = vec![0; 1024 * 1024];
-            let bytes_read = exec_context.get_output(index, &mut destination)?;
-            destination.truncate(bytes_read as usize);
-            Ok(Ok(destination))
-        } else {
-            Err(UsageError::InvalidGraphHandle.into())
+        exec_context: Resource<GraphExecutionContext>,
+        name: String,
+    ) -> Result<Resource<Tensor>, Error> {
+        let exec_context = self.table.get_mut(&exec_context)?;
+        tracing::debug!("get output {name:?}");
+        match exec_context.get_output(Id::Name(name)) {
+            Ok(tensor) => {
+                let tensor = self.table.push(tensor)?;
+                Ok(tensor)
+            }
+            Err(error) => {
+                tracing::error!("failed to get output: {error:?}");
+                Err(Error::RuntimeError)
+            }
         }
     }
+
+    fn drop(&mut self, exec_context: Resource<GraphExecutionContext>) -> wasmtime::Result<()> {
+        self.table.delete(exec_context)?;
+        Ok(())
+    }
 }
-impl gen::errors::Host for WasiNnCtx {}
+impl gen::tensor::HostTensor for WasiNnView<'_> {
+    fn new(
+        &mut self,
+        dimensions: TensorDimensions,
+        ty: TensorType,
+        data: TensorData,
+    ) -> wasmtime::Result<Resource<Tensor>> {
+        let tensor = Tensor {
+            dimensions,
+            ty,
+            data,
+        };
+        let tensor = self.table.push(tensor)?;
+        Ok(tensor)
+    }
+
+    fn dimensions(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorDimensions> {
+        let tensor = self.table.get(&tensor)?;
+        Ok(tensor.dimensions.clone())
+    }
+
+    fn ty(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorType> {
+        let tensor = self.table.get(&tensor)?;
+        Ok(tensor.ty)
+    }
+
+    fn data(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<TensorData> {
+        let tensor = self.table.get(&tensor)?;
+        Ok(tensor.data.clone())
+    }
-impl gen::tensor::Host for WasiNnCtx {}
+    fn drop(&mut self, tensor: Resource<Tensor>) -> wasmtime::Result<()> {
+        self.table.delete(tensor)?;
+        Ok(())
+    }
+}
+
+impl gen::tensor::Host for WasiNnView<'_> {}
+impl gen::errors::Host for WasiNnView<'_> {
+    fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
+        match err {
+            Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
+            Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
+            Error::Timeout => Ok(gen::errors::Error::Timeout),
+            Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
+            Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
+            Error::TooLarge => Ok(gen::errors::Error::TooLarge),
+            Error::NotFound => Ok(gen::errors::Error::NotFound),
+            Error::Trap(e) => Err(e),
+        }
+    }
+}
+impl gen::inference::Host for WasiNnView<'_> {}
 impl Hash for gen::graph::GraphEncoding {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        core::mem::discriminant(self).hash(state);
+        self.to_string().hash(state)
+    }
+}
+
+impl fmt::Display for gen::graph::GraphEncoding {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        use gen::graph::GraphEncoding::*;
+        match self {
+            Openvino => write!(f, "openvino"),
+            Onnx => write!(f, "onnx"),
+            Pytorch => write!(f, "pytorch"),
+            Tensorflow => write!(f, "tensorflow"),
+            Tensorflowlite => write!(f, "tensorflowlite"),
+            Autodetect => write!(f, "autodetect"),
+            Ggml => write!(f, "ggml"),
+        }
     }
 }
@@ -168,4 +344,4 @@ impl fmt::Display for GraphEncodingParseError {
         write!(f, "unknown graph encoding: {}", self.0)
     }
 }
-impl Error for GraphEncodingParseError {}
+impl std::error::Error for GraphEncodingParseError {}
diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs
index 9d78a595ec90..f4f2eab90647 100644
--- a/crates/wasi-nn/src/witx.rs
+++ b/crates/wasi-nn/src/witx.rs
@@ -13,11 +13,83 @@
 //!
 //! [`types`]: crate::wit::types
-use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result};
-use wiggle::{GuestMemory, GuestPtr};
+use crate::backend::BackendError;
+use crate::backend::Id;
+use crate::wit::GraphEncoding;
+use crate::{Backend, ExecutionContext, Graph, Registry};
+use std::collections::HashMap;
+use std::hash::Hash;
+use thiserror::Error;
+use wiggle::{GuestError, GuestMemory, GuestPtr};
 pub use gen::wasi_ephemeral_nn::add_to_linker;
+pub(crate) type WasiNnResult<T> = std::result::Result<T, WasiNnError>;
+type Result<T> = WasiNnResult<T>;
+type GraphId = u32;
+type GraphExecutionContextId = u32;
+
+/// Capture the state necessary for calling into the backend ML libraries.
+pub struct WasiNnCtx {
+    pub(crate) backends: HashMap<GraphEncoding, Backend>,
+    pub(crate) registry: Registry,
+    pub(crate) graphs: Table<GraphId, Graph>,
+    pub(crate) executions: Table<GraphExecutionContextId, ExecutionContext>,
+}
+
+impl WasiNnCtx {
+    /// Make a new context from the default state.
+    pub fn new(backends: impl IntoIterator<Item = Backend>, registry: Registry) -> Self {
+        let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect();
+        Self {
+            backends,
+            registry,
+            graphs: Table::default(),
+            executions: Table::default(),
+        }
+    }
+}
+
+/// Record handle entries in a table.
+pub struct Table<K, V> {
+    entries: HashMap<K, V>,
+    next_key: u32,
+}
+
+impl<K, V> Default for Table<K, V> {
+    fn default() -> Self {
+        Self {
+            entries: HashMap::new(),
+            next_key: 0,
+        }
+    }
+}
+
+impl<K, V> Table<K, V>
+where
+    K: Eq + Hash + From<u32> + Copy,
+{
+    pub fn insert(&mut self, value: V) -> K {
+        let key = self.use_next_key();
+        self.entries.insert(key, value);
+        key
+    }
+
+    pub fn get(&self, key: K) -> Option<&V> {
+        self.entries.get(&key)
+    }
+
+    pub fn get_mut(&mut self, key: K) -> Option<&mut V> {
+        self.entries.get_mut(&key)
+    }
+
+    fn use_next_key(&mut self) -> K {
+        let current = self.next_key;
+        self.next_key += 1;
+        K::from(current)
+    }
+}
+
 /// Generate the traits and types from the `wasi-nn` WITX specification.
 mod gen {
     use super::*;
@@ -42,9 +114,10 @@ mod gen {
         ) -> anyhow::Result<types::NnErrno> {
             tracing::debug!("host error: {:?}", e);
             match e {
-                WasiNnError::BackendError(_) => unimplemented!(),
-                WasiNnError::GuestError(_) => unimplemented!(),
-                WasiNnError::UsageError(_) => unimplemented!(),
+                WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError),
+                WasiNnError::GuestError(_) => unimplemented!("guest error conversion"),
+                WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation),
+                WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge),
             }
         }
     }
@@ -119,10 +192,10 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
             let tensor = crate::wit::types::Tensor {
                 dimensions: memory.to_vec(tensor.dimensions)?,
-                tensor_type: tensor.type_.into(),
+                ty: tensor.type_.into(),
                 data: memory.to_vec(tensor.data)?,
             };
-            Ok(exec_context.set_input(index, &tensor)?)
+            Ok(exec_context.set_input(Id::Index(index), &tensor)?)
         } else {
             Err(UsageError::InvalidGraphHandle.into())
         }
@@ -149,13 +222,19 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx {
         out_buffer_max_size: u32,
     ) -> Result<u32> {
         if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) {
-            let mut destination = memory
+            let tensor = exec_context.get_output(Id::Index(index))?;
+            let destination = memory
                 .as_slice_mut(out_buffer.as_array(out_buffer_max_size))?
                 .expect(
                     "cannot use with shared memories; \
                     see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)",
                 );
-            Ok(exec_context.get_output(index, &mut destination)?)
+            if tensor.data.len() > destination.len() {
+                Err(WasiNnError::NotEnoughMemory(tensor.data.len()))
+            } else {
+                destination[..tensor.data.len()].copy_from_slice(&tensor.data);
+                Ok(tensor.data.len() as u32)
+            }
         } else {
             Err(UsageError::InvalidGraphHandle.into())
         }
@@ -199,3 +278,28 @@ impl From<gen::types::TensorType> for crate::wit::types::TensorType {
         }
     }
 }
+
+/// Possible errors while interacting with [WasiNnCtx].
+#[derive(Debug, Error)]
+pub enum WasiNnError {
+    #[error("backend error")]
+    BackendError(#[from] BackendError),
+    #[error("guest error")]
+    GuestError(#[from] GuestError),
+    #[error("usage error")]
+    UsageError(#[from] UsageError),
+    #[error("not enough memory: requested {0} bytes")]
+    NotEnoughMemory(usize),
+}
+
+#[derive(Debug, Error)]
+pub enum UsageError {
+    #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")]
+    InvalidEncoding(GraphEncoding),
+    #[error("Invalid graph handle; has it been loaded?")]
+    InvalidGraphHandle,
+    #[error("Invalid execution context handle; has it been initialized?")]
+    InvalidExecutionContextHandle,
+    #[error("No graph found with name: {0}")]
+    NotFound(String),
+}
diff --git a/crates/wasi-nn/tests/check/mod.rs b/crates/wasi-nn/tests/check/mod.rs
index ffdc099b3009..9f59d38130af 100644
--- a/crates/wasi-nn/tests/check/mod.rs
+++ b/crates/wasi-nn/tests/check/mod.rs
@@ -1,10 +1,8 @@
-//! This is testing-specific code--it is public only so that it can be
-//! accessible both in unit and integration tests.
+//! Check that the environment is set up correctly for running tests.
 //!
 //! This module checks:
-//! - that OpenVINO can be found in the environment
-//! - that WinML is available
-//! - that some ML model artifacts can be downloaded and cached.
+//! - that various backends can be located on the system (see sub-modules)
+//! - that certain ML model artifacts can be downloaded and cached.
 #[allow(unused_imports)]
 use anyhow::{anyhow, Context, Result};
diff --git a/crates/wasi-nn/tests/exec/mod.rs b/crates/wasi-nn/tests/exec/mod.rs
index 23840e7a5d3a..5fd25cbd3cbf 100644
--- a/crates/wasi-nn/tests/exec/mod.rs
+++ b/crates/wasi-nn/tests/exec/mod.rs
@@ -1,52 +1,6 @@
-use crate::check::artifacts_dir;
-use anyhow::Result;
-use std::path::Path;
-use wasi_common::sync::{Dir, WasiCtxBuilder};
-use wasi_common::WasiCtx;
-use wasmtime::{Config, Engine, Linker, Module, Store};
-use wasmtime_wasi_nn::{Backend, InMemoryRegistry, WasiNnCtx};
+//! Provide a Wasmtime embedding for executing wasi-nn test programs.
-const PREOPENED_DIR_NAME: &str = "fixture";
+pub mod wit;
+pub mod witx;
-/// Run a wasi-nn test program. This is modeled after
-/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
-/// for file reads.
-pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
-    let path = Path::new(path);
-    let engine = Engine::new(&Config::new())?;
-    let mut linker = Linker::new(&engine);
-    wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?;
-    wasi_common::sync::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi)?;
-    let module = Module::from_file(&engine, path)?;
-    let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
-    let instance = linker.instantiate(&mut store, &module)?;
-    let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;
-    start.call(&mut store, ())?;
-    Ok(())
-}
-
-/// The host state for running wasi-nn tests.
-struct Ctx {
-    wasi: WasiCtx,
-    wasi_nn: WasiNnCtx,
-}
-
-impl Ctx {
-    fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
-        let preopen_dir = Dir::open_ambient_dir(preopen_dir, cap_std::ambient_authority())?;
-        let mut builder = WasiCtxBuilder::new();
-        builder
-            .inherit_stdio()
-            .preopened_dir(preopen_dir, PREOPENED_DIR_NAME)?;
-        let wasi = builder.build();
-
-        let mut registry = InMemoryRegistry::new();
-        let mobilenet_dir = artifacts_dir();
-        if preload_model {
-            registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
-        }
-        let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
-
-        Ok(Self { wasi, wasi_nn })
-    }
-}
+pub const PREOPENED_DIR_NAME: &str = "fixture";
diff --git a/crates/wasi-nn/tests/exec/wit.rs b/crates/wasi-nn/tests/exec/wit.rs
new file mode 100644
index 000000000000..5f2d546d667d
--- /dev/null
+++ b/crates/wasi-nn/tests/exec/wit.rs
@@ -0,0 +1,73 @@
+use super::PREOPENED_DIR_NAME;
+use crate::check::artifacts_dir;
+use anyhow::{anyhow, Result};
+use std::path::Path;
+use wasmtime::component::{Component, Linker, ResourceTable};
+use wasmtime::{Config, Engine, Store};
+use wasmtime_wasi::bindings::sync::Command;
+use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder};
+use wasmtime_wasi_nn::wit::WasiNnView;
+use wasmtime_wasi_nn::{wit::WasiNnCtx, Backend, InMemoryRegistry};
+
+/// Run a wasi-nn test program. This is modeled after
+/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API for
+/// file reads.
+pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
+    let path = Path::new(path);
+    let engine = Engine::new(&Config::new())?;
+    let mut linker = Linker::new(&engine);
+    wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |c: &mut Ctx| {
+        WasiNnView::new(&mut c.table, &mut c.wasi_nn)
+    })?;
+    wasmtime_wasi::add_to_linker_sync(&mut linker)?;
+    let module = Component::from_file(&engine, path)?;
+    let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
+    let command = Command::instantiate(&mut store, &module, &linker)?;
+    let result = command.wasi_cli_run().call_run(&mut store)?;
+    result.map_err(|_| anyhow!("failed to run command"))
+}
+
+/// The host state for running wasi-nn component tests.
+struct Ctx {
+    wasi: WasiCtx,
+    wasi_nn: WasiNnCtx,
+    table: ResourceTable,
+}
+
+impl Ctx {
+    fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
+        let mut builder = WasiCtxBuilder::new();
+        builder.inherit_stdio().preopened_dir(
+            preopen_dir,
+            PREOPENED_DIR_NAME,
+            DirPerms::READ,
+            FilePerms::READ,
+        )?;
+        let wasi = builder.build();
+
+        let mut registry = InMemoryRegistry::new();
+        let mobilenet_dir = artifacts_dir();
+        if preload_model {
+            registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
+        }
+        let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
+
+        let table = ResourceTable::new();
+
+        Ok(Self {
+            wasi,
+            wasi_nn,
+            table,
+        })
+    }
+}
+
+impl wasmtime_wasi::WasiView for Ctx {
+    fn ctx(&mut self) -> &mut WasiCtx {
+        &mut self.wasi
+    }
+
+    fn table(&mut self) -> &mut ResourceTable {
+        &mut self.table
+    }
+}
diff --git a/crates/wasi-nn/tests/exec/witx.rs b/crates/wasi-nn/tests/exec/witx.rs
new file mode 100644
index 000000000000..21feea1dd641
--- /dev/null
+++ b/crates/wasi-nn/tests/exec/witx.rs
@@ -0,0 +1,52 @@
+use super::PREOPENED_DIR_NAME;
+use crate::check::artifacts_dir;
+use anyhow::Result;
+use std::path::Path;
+use wasmtime::{Config, Engine, Linker, Module, Store};
+use wasmtime_wasi::{preview1::WasiP1Ctx, DirPerms, FilePerms, WasiCtxBuilder};
+use wasmtime_wasi_nn::{witx::WasiNnCtx, Backend, InMemoryRegistry};
+
+/// Run a wasi-nn test program. This is modeled after
+/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API
+/// for file reads.
+pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> {
+    let path = Path::new(path);
+    let engine = Engine::new(&Config::new())?;
+    let mut linker = Linker::new(&engine);
+    wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?;
+    wasmtime_wasi::preview1::add_to_linker_sync(&mut linker, |s: &mut Ctx| &mut s.wasi)?;
+    let module = Module::from_file(&engine, path)?;
+    let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?);
+    let instance = linker.instantiate(&mut store, &module)?;
+    let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;
+    start.call(&mut store, ())?;
+    Ok(())
+}
+
+/// The host state for running wasi-nn tests.
+struct Ctx {
+    wasi: WasiP1Ctx,
+    wasi_nn: WasiNnCtx,
+}
+
+impl Ctx {
+    fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result<Self> {
+        let mut builder = WasiCtxBuilder::new();
+        builder.inherit_stdio().preopened_dir(
+            preopen_dir,
+            PREOPENED_DIR_NAME,
+            DirPerms::READ,
+            FilePerms::READ,
+        )?;
+        let wasi = builder.build_p1();
+
+        let mut registry = InMemoryRegistry::new();
+        let mobilenet_dir = artifacts_dir();
+        if preload_model {
+            registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?;
+        }
+        let wasi_nn = WasiNnCtx::new([backend.into()], registry.into());
+
+        Ok(Self { wasi, wasi_nn })
+    }
+}
diff --git a/crates/wasi-nn/tests/fixtures/readme.md b/crates/wasi-nn/tests/fixtures/README.md
similarity index 100%
rename from crates/wasi-nn/tests/fixtures/readme.md
rename to crates/wasi-nn/tests/fixtures/README.md
diff --git a/crates/wasi-nn/tests/test-programs.rs b/crates/wasi-nn/tests/test-programs.rs
index 6dfb89a90e58..e375063b1cd8 100644
--- a/crates/wasi-nn/tests/test-programs.rs
+++ b/crates/wasi-nn/tests/test-programs.rs
@@ -23,6 +23,8 @@ use test_programs_artifacts::*;
 use wasmtime_wasi_nn::{backend, Backend};
 fn main() -> Result<()> {
+    tracing_subscriber::fmt::init();
+
     if cfg!(miri) {
         return Ok(());
     }
@@ -45,7 +47,7 @@ fn main() -> Result<()> {
     let mut trials = Vec::new();
     for program in programs {
         // Either ignore the test if it cannot run (i.e., downgrade `Fail` to
-        // `Ignore`) or pre-emptively fail it if `error_on_failed_check` is set.
+        // `Ignore`) or preemptively fail it if `error_on_failed_check` is set.
         let (run_test, mut check) = check_test_program(program);
         if !error_on_failed_check {
             check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`.
@@ -68,103 +70,122 @@ fn main() -> Result<()> {
 /// Return the test program to run and a check that must pass for the test to
 /// run.
fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) { - use IgnoreCheck::*; match name { - "nn_image_classification" => ( - nn_image_classification, - if !cfg!(target_arch = "x86_64") { - Fail("requires x86_64".into()) - } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") { - Fail("requires linux or windows".into()) - } else if let Err(e) = check::openvino::is_installed() { - Fail(e.to_string().into()) - } else { - Run - }, + // Legacy WITX-based tests: + "nn_witx_image_classification_openvino" => ( + nn_witx_image_classification_openvino, + IgnoreCheck::for_openvino(), + ), + "nn_witx_image_classification_openvino_named" => ( + nn_witx_image_classification_openvino_named, + IgnoreCheck::for_openvino(), + ), + "nn_witx_image_classification_onnx" => { + (nn_witx_image_classification_onnx, IgnoreCheck::for_onnx()) + } + "nn_witx_image_classification_winml_named" => ( + nn_witx_image_classification_winml_named, + IgnoreCheck::for_winml(), ), - "nn_image_classification_named" => ( - nn_image_classification_named, - if !cfg!(target_arch = "x86_64") { - Fail("requires x86_64".into()) - } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") { - Fail("requires linux or windows or macos".into()) - } else if let Err(e) = check::openvino::is_installed() { - Fail(e.to_string().into()) - } else { - Run - }, + // WIT-based tests: + "nn_wit_image_classification_openvino" => ( + nn_wit_image_classification_openvino, + IgnoreCheck::for_openvino(), ), - "nn_image_classification_onnx" => ( - nn_image_classification_onnx, - #[cfg(feature = "onnx")] - if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") { - Fail("requires x86_64 or aarch64".into()) - } else if !cfg!(target_os = "linux") - && !cfg!(target_os = "windows") - && !cfg!(target_os = "macos") - { - Fail("requires linux, windows, or macos".into()) - } else { - Run - }, - #[cfg(not(feature = "onnx"))] - Ignore("requires the `onnx` feature".into()), + "nn_wit_image_classification_openvino_named" => ( + nn_wit_image_classification_openvino_named, + IgnoreCheck::for_openvino(), ), - "nn_image_classification_winml" => ( - nn_image_classification_winml, - #[cfg(all(feature = "winml", target_os = "windows"))] - if !cfg!(target_arch = "x86_64") { - Fail("requires x86_64".into()) - } else if cfg!(target_os = "windows") { - Fail("requires windows".into()) - } else if let Err(e) = check::winml::is_available() { - Fail(e.to_string().into()) - } else { - Run - }, - #[cfg(not(all(feature = "winml", target_os = "windows")))] - Ignore("requires the `winml` feature on windows".into()), + "nn_wit_image_classification_onnx" => { + (nn_wit_image_classification_onnx, IgnoreCheck::for_onnx()) + } + "nn_wit_image_classification_winml_named" => ( + nn_wit_image_classification_winml_named, + IgnoreCheck::for_winml(), ), _ => panic!("unknown test program: {} (add to this `match`)", name), } } -fn nn_image_classification() -> Result<()> { +fn nn_witx_image_classification_openvino() -> Result<()> { check::openvino::is_installed()?; check::openvino::are_artifacts_available()?; let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION, backend, false) + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO, backend, false) } -fn nn_image_classification_named() -> Result<()> { +fn nn_witx_image_classification_openvino_named() -> Result<()> { check::openvino::is_installed()?; check::openvino::are_artifacts_available()?; let backend = 
Backend::from(backend::openvino::OpenvinoBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION_NAMED, backend, true) + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO_NAMED, backend, true) } #[cfg(feature = "onnx")] -fn nn_image_classification_onnx() -> Result<()> { +fn nn_witx_image_classification_onnx() -> Result<()> { check::onnx::are_artifacts_available()?; - let backend = Backend::from(backend::onnxruntime::OnnxBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false) + let backend = Backend::from(backend::onnx::OnnxBackend::default()); + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false) } - #[cfg(not(feature = "onnx"))] -fn nn_image_classification_onnx() -> Result<()> { +fn nn_witx_image_classification_onnx() -> Result<()> { anyhow::bail!("this test requires the `onnx` feature") } #[cfg(all(feature = "winml", target_os = "windows"))] -fn nn_image_classification_winml() -> Result<()> { +fn nn_witx_image_classification_winml_named() -> Result<()> { check::winml::is_available()?; check::onnx::are_artifacts_available()?; let backend = Backend::from(backend::winml::WinMLBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false) + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false) +} +#[cfg(not(all(feature = "winml", target_os = "windows")))] +fn nn_witx_image_classification_winml_named() -> Result<()> { + anyhow::bail!("this test requires the `winml` feature and only runs on windows") } +fn nn_wit_image_classification_openvino() -> Result<()> { + check::openvino::is_installed()?; + check::openvino::are_artifacts_available()?; + let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); + exec::wit::run( + NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_COMPONENT, + backend, + false, + ) +} + +fn nn_wit_image_classification_openvino_named() -> Result<()> { + check::openvino::is_installed()?; + check::openvino::are_artifacts_available()?; + let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); + exec::wit::run( + NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_NAMED_COMPONENT, + backend, + true, + ) +} + +#[cfg(feature = "onnx")] +fn nn_wit_image_classification_onnx() -> Result<()> { + check::onnx::are_artifacts_available()?; + let backend = Backend::from(backend::onnx::OnnxBackend::default()); + exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false) +} +#[cfg(not(feature = "onnx"))] +fn nn_wit_image_classification_onnx() -> Result<()> { + anyhow::bail!("this test requires the `onnx` feature") +} + +#[cfg(all(feature = "winml", target_os = "windows"))] +fn nn_wit_image_classification_winml_named() -> Result<()> { + check::winml::is_available()?; + check::onnx::are_artifacts_available()?; + let backend = Backend::from(backend::winml::WinMLBackend::default()); + exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false) +} #[cfg(not(all(feature = "winml", target_os = "windows")))] -fn nn_image_classification_winml() -> Result<()> { +fn nn_wit_image_classification_winml_named() -> Result<()> { anyhow::bail!("this test requires the `winml` feature and only runs on windows") } @@ -197,3 +218,52 @@ impl IgnoreCheck { matches!(self, IgnoreCheck::Ignore(_)) } } + +/// Some pre-test checks for various backends. 
+impl IgnoreCheck {
+    fn for_openvino() -> IgnoreCheck {
+        use IgnoreCheck::*;
+        if !cfg!(target_arch = "x86_64") {
+            Fail("requires x86_64".into())
+        } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
+            Fail("requires linux or windows or macos".into())
+        } else if let Err(e) = check::openvino::is_installed() {
+            Fail(e.to_string().into())
+        } else {
+            Run
+        }
+    }
+
+    fn for_onnx() -> Self {
+        use IgnoreCheck::*;
+        #[cfg(feature = "onnx")]
+        if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
+            Fail("requires x86_64 or aarch64".into())
+        } else if !cfg!(target_os = "linux")
+            && !cfg!(target_os = "windows")
+            && !cfg!(target_os = "macos")
+        {
+            Fail("requires linux, windows, or macos".into())
+        } else {
+            Run
+        }
+        #[cfg(not(feature = "onnx"))]
+        Ignore("requires the `onnx` feature".into())
+    }
+
+    fn for_winml() -> IgnoreCheck {
+        use IgnoreCheck::*;
+        #[cfg(all(feature = "winml", target_os = "windows"))]
+        if !cfg!(target_arch = "x86_64") {
+            Fail("requires x86_64".into())
+        } else if !cfg!(target_os = "windows") {
+            Fail("requires windows".into())
+        } else if let Err(e) = check::winml::is_available() {
+            Fail(e.to_string().into())
+        } else {
+            Run
+        }
+        #[cfg(not(all(feature = "winml", target_os = "windows")))]
+        Ignore("requires the `winml` feature on windows".into())
+    }
+}
diff --git a/crates/wasi-nn/wit/wasi-nn.wit b/crates/wasi-nn/wit/wasi-nn.wit
index 19e3de875d61..b8ffd22e8c04 100644
--- a/crates/wasi-nn/wit/wasi-nn.wit
+++ b/crates/wasi-nn/wit/wasi-nn.wit
@@ -43,16 +43,18 @@ interface tensor {
     /// memory--e.g., using row-major ordering--and could perhaps be improved.
     type tensor-data = list<u8>;
-    record tensor {
+    resource tensor {
+        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
+
         // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
         // containing a single value, use `[1]` for the tensor dimensions.
-        dimensions: tensor-dimensions,
+        dimensions: func() -> tensor-dimensions;
+
         // Describe the type of element in the tensor (e.g., `f32`).
-        tensor-type: tensor-type,
+        ty: func() -> tensor-type;
-        // Contains the tensor data.
-        data: tensor-data,
+        // Return the tensor data.
+        data: func() -> tensor-data;
     }
 }
@@ -61,11 +63,12 @@ interface tensor {
 interface graph {
     use errors.{error};
     use tensor.{tensor};
+    use inference.{graph-execution-context};
     /// An execution graph for performing inference (i.e., a model).
-    ///
-    /// TODO: replace with `resource` (https://github.com/WebAssembly/wasi-nn/issues/47).
-    type graph = u32;
+    resource graph {
+        init-execution-context: func() -> result<graph-execution-context, error>;
+    }
     /// Describes the encoding of the graph. This allows the API to be implemented by various
     /// backends that encode (i.e., serialize) their graph IR with different formats.
@@ -75,6 +78,7 @@ interface graph {
         tensorflow,
         pytorch,
         tensorflowlite,
+        ggml,
         autodetect,
     }
@@ -107,27 +111,25 @@ interface graph {
 interface inference {
     use errors.{error};
     use tensor.{tensor, tensor-data};
-    use graph.{graph};
     /// Bind a `graph` to the input and output tensors for an inference.
     ///
-    /// TODO: this is no longer necessary in WIT (https://github.com/WebAssembly/wasi-nn/issues/43)
-    type graph-execution-context = u32;
-
-    /// Create an execution instance of a loaded graph.
-    init-execution-context: func(graph: graph) -> result<graph-execution-context, error>;
-
-    /// Define the inputs to use for inference.
-    set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>;
-
-    /// Compute the inference on the given inputs.
-    ///
-    /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
-    /// expectation could be removed as a part of https://github.com/WebAssembly/wasi-nn/issues/43.
-    compute: func(ctx: graph-execution-context) -> result<_, error>;
-
-    /// Extract the outputs after inference.
-    get-output: func(ctx: graph-execution-context, index: u32) -> result<tensor-data, error>;
+    /// TODO: this may no longer be necessary in WIT
+    /// (https://github.com/WebAssembly/wasi-nn/issues/43)
+    resource graph-execution-context {
+        /// Define the inputs to use for inference.
+        set-input: func(name: string, tensor: tensor) -> result<_, error>;
+
+        /// Compute the inference on the given inputs.
+        ///
+        /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
+        /// expectation could be removed as a part of
+        /// https://github.com/WebAssembly/wasi-nn/issues/43.
+        compute: func() -> result<_, error>;
+
+        /// Extract the outputs after inference.
+        get-output: func(name: string) -> result<tensor, error>;
+    }
 }
 /// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
@@ -137,7 +139,8 @@ interface errors {
         invalid-argument,
         // Invalid encoding.
         invalid-encoding,
-        busy,
+        // The operation timed out.
+        timeout,
         // Runtime Error.
         runtime-error,
         // Unsupported operation.
diff --git a/src/commands/run.rs b/src/commands/run.rs
index 317b91aa22d7..e5b3b816a28a 100644
--- a/src/commands/run.rs
+++ b/src/commands/run.rs
@@ -18,7 +18,7 @@ use wasmtime::{Engine, Func, Module, Store, StoreLimits, Val, ValType};
 use wasmtime_wasi::WasiView;
 #[cfg(feature = "wasi-nn")]
-use wasmtime_wasi_nn::WasiNnCtx;
+use wasmtime_wasi_nn::wit::WasiNnView;
 #[cfg(feature = "wasi-threads")]
 use wasmtime_wasi_threads::WasiThreadsCtx;
@@ -624,40 +624,37 @@ impl RunCommand {
             {
                 bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
             }
-            #[cfg(feature = "wasi-nn")]
+            #[cfg(all(feature = "wasi-nn", feature = "component-model"))]
             {
+                let (backends, registry) = self.collect_preloaded_nn_graphs()?;
                 match linker {
                     CliLinker::Core(linker) => {
                         wasmtime_wasi_nn::witx::add_to_linker(linker, |host| {
-                            // This WASI proposal is currently not protected against
-                            // concurrent access--i.e., when wasi-threads is actively
-                            // spawning new threads, we cannot (yet) safely allow access and
-                            // fail if more than one thread has `Arc`-references to the
-                            // context. Once this proposal is updated (as wasi-common has
-                            // been) to allow concurrent access, this `Arc::get_mut`
-                            // limitation can be removed.
-                            Arc::get_mut(host.wasi_nn.as_mut().unwrap())
+                            Arc::get_mut(host.wasi_nn_witx.as_mut().unwrap())
                                 .expect("wasi-nn is not implemented with multi-threading support")
                         })?;
+                        store.data_mut().wasi_nn_witx = Some(Arc::new(
+                            wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry),
+                        ));
                     }
                     #[cfg(feature = "component-model")]
                     CliLinker::Component(linker) => {
-                        wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| {
-                            Arc::get_mut(host.wasi_nn.as_mut().unwrap())
-                                .expect("wasi-nn is not implemented with multi-threading support")
+                        wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| {
+                            let preview2_ctx =
+                                h.preview2_ctx.as_mut().expect("wasip2 is not configured");
+                            let preview2_ctx = Arc::get_mut(preview2_ctx)
+                                .expect("wasmtime_wasi is not compatible with threads")
+                                .get_mut()
+                                .unwrap();
+                            let nn_ctx = Arc::get_mut(h.wasi_nn_wit.as_mut().unwrap())
+                                .expect("wasi-nn is not implemented with multi-threading support");
+                            WasiNnView::new(preview2_ctx.table(), nn_ctx)
                         })?;
+                        store.data_mut().wasi_nn_wit = Some(Arc::new(
+                            wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry),
+                        ));
                     }
                 }
-                let graphs = self
-                    .run
-                    .common
-                    .wasi
-                    .nn_graph
-                    .iter()
-                    .map(|g| (g.format.clone(), g.dir.clone()))
-                    .collect::<Vec<_>>();
-                let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?;
-                store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new(backends, registry)));
             }
         }
@@ -767,6 +764,21 @@ impl RunCommand {
         store.data_mut().preview2_ctx = Some(Arc::new(Mutex::new(ctx)));
         Ok(())
     }
+
+    #[cfg(feature = "wasi-nn")]
+    fn collect_preloaded_nn_graphs(
+        &self,
+    ) -> Result<(Vec<wasmtime_wasi_nn::Backend>, wasmtime_wasi_nn::Registry)> {
+        let graphs = self
+            .run
+            .common
+            .wasi
+            .nn_graph
+            .iter()
+            .map(|g| (g.format.clone(), g.dir.clone()))
+            .collect::<Vec<_>>();
+        wasmtime_wasi_nn::preload(&graphs)
+    }
 }
 #[derive(Default, Clone)]
@@ -779,7 +791,10 @@ struct Host {
     preview2_ctx: Option<Arc<Mutex<wasmtime_wasi::preview1::WasiP1Ctx>>>,
     #[cfg(feature = "wasi-nn")]
-    wasi_nn: Option<Arc<WasiNnCtx>>,
+    wasi_nn_wit: Option<Arc<wasmtime_wasi_nn::wit::WasiNnCtx>>,
+    #[cfg(feature = "wasi-nn")]
+    wasi_nn_witx: Option<Arc<wasmtime_wasi_nn::witx::WasiNnCtx>>,
+
     #[cfg(feature = "wasi-threads")]
     wasi_threads: Option<Arc<WasiThreadsCtx<Host>>>,
     #[cfg(feature = "wasi-http")]
diff --git a/src/commands/serve.rs b/src/commands/serve.rs
index 56c2af9d3024..8a200ee093f3 100644
--- a/src/commands/serve.rs
+++ b/src/commands/serve.rs
@@ -17,7 +17,7 @@ use wasmtime_wasi_http::io::TokioIo;
 use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView};
 #[cfg(feature = "wasi-nn")]
-use wasmtime_wasi_nn::WasiNnCtx;
+use wasmtime_wasi_nn::wit::WasiNnCtx;
 struct Host {
     table: wasmtime::component::ResourceTable,
@@ -75,15 +75,8 @@ impl ServeCommand {
     pub fn execute(mut self) -> Result<()> {
         self.run.common.init_logging()?;
-        // We force cli errors before starting to listen for connections so then we don't
-        // accidentally delay them to the first request.
-        if self.run.common.wasi.nn == Some(true) {
-            #[cfg(not(feature = "wasi-nn"))]
-            {
-                bail!("Cannot enable wasi-nn when the binary is not compiled with this feature.");
-            }
-        }
-
+        // We force cli errors before starting to listen for connections so then
+        // we don't accidentally delay them to the first request.
         if let Some(Profile::Guest { .. }) = &self.run.profile {
             bail!("Cannot use the guest profiler with components");
         }
@@ -99,8 +92,8 @@ impl ServeCommand {
             bail!("wasi-threads does not support components yet")
         }
-        // The serve command requires both wasi-http and the component model, so we enable those by
-        // default here.
+        // The serve command requires both wasi-http and the component model, so
+        // we enable those by default here.
if self.run.common.wasi.http.replace(true) == Some(false) { bail!("wasi-http is required for the serve command, and must not be disabled"); } @@ -227,7 +220,10 @@ impl ServeCommand { } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| host.nn.as_mut().unwrap())?; + wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| { + let ctx = h.nn.as_mut().unwrap(); + wasmtime_wasi_nn::wit::WasiNnView::new(&mut h.table, ctx) + })?; } } diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index e29b1cfc4fd0..7cc1fa51a95a 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -2028,6 +2028,12 @@ criteria = "safe-to-deploy" version = "0.46.0" notes = "one use of unsafe to call windows specific api to get console handle." +[[audits.num-traits]] +who = "Andrew Brown " +criteria = "safe-to-deploy" +version = "0.2.19" +notes = "As advertised: a numeric library. The only `unsafe` is from some float-to-int conversions, which seems expected." + [[audits.num_cpus]] who = "Alex Crichton " criteria = "safe-to-deploy" @@ -2145,12 +2151,24 @@ criteria = "safe-to-deploy" version = "2.0.0-rc.0" notes = "As expected, this crate uses `unsafe` to access the `unsafe` `ort-sys` FFI functions; it also includes several `unsafe` implementations of `Send` for several structures. With the `load-dynamic` feature enabled, this crate will be `libloading` external libraries to call FFI functions. With the `fetch-models` feature enabled, this crate can also download arbitrary models to the local filesystem." +[[audits.ort]] +who = "Andrew Brown " +criteria = "safe-to-deploy" +delta = "2.0.0-rc.0 -> 2.0.0-rc.2" +notes = "Same as previous audit: the crate inherently uses `unsafe` FFI calls for using ONNX through `ort-sys` (e.g., logging C error strings). The changes are relatively uninteresting: a lot of documentation, some `must_use`, and general refactoring due to changes in the underlying API." + [[audits.ort-sys]] who = "Andrew Brown " criteria = "safe-to-deploy" version = "2.0.0-rc.0" notes = "As expected, this crate contains a significant number of `unsafe` definitions to expose the FFI surface of the ONNX libraries. Perhaps surprisingly, it also contains some `unsafe` system calls to locate the user's home directory. Another interesting bit is the `build.rs` script: with the `download-binaries` feature enabled, this script will retrieve and link various ONNX libraries from https://parcel.pyke.io. This seems par for the course with this kind of library, though; the alternative--attempting to find the library on an arbitrary system--can be quite complex." +[[audits.ort-sys]] +who = "Andrew Brown " +criteria = "safe-to-deploy" +delta = "2.0.0-rc.0 -> 2.0.0-rc.2" +notes = "This crate still downloads the ONNX libraries as a part of the `build.rs` script; now with more platform options for pre-built binaries stored in a `dist.txt` file. Otherwise largely unchanged since the previous audit." + [[audits.overload]] who = "Pat Hickey " criteria = "safe-to-deploy"
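
Editor's note (not part of the patch): the hunks above show the new WIT-based wasi-nn host being wired up in several embedders (`crates/wasi-nn/tests/exec/wit.rs`, `src/commands/run.rs`, `src/commands/serve.rs`). A minimal sketch of that wiring, assuming the crate layout introduced by this change (`wasmtime_wasi_nn::wit::{add_to_linker, WasiNnCtx, WasiNnView}` and `wasmtime_wasi_nn::preload`), might look like the following; the `Host` struct and the `make_host`/`link_wasi_nn` helpers are illustrative only, not part of the Wasmtime API.

use wasmtime::component::{Linker, ResourceTable};
use wasmtime::Engine;
use wasmtime_wasi_nn::wit::{self, WasiNnCtx, WasiNnView};

// Hypothetical embedder state: a resource table plus the wasi-nn context.
struct Host {
    table: ResourceTable,
    wasi_nn: WasiNnCtx,
}

fn make_host() -> anyhow::Result<Host> {
    // `preload` discovers the compiled-in backends; an empty graph list
    // yields an empty registry (no named graphs), as in the bench API above.
    let (backends, registry) = wasmtime_wasi_nn::preload(&[])?;
    Ok(Host {
        table: ResourceTable::new(),
        wasi_nn: WasiNnCtx::new(backends, registry),
    })
}

fn link_wasi_nn(engine: &Engine) -> anyhow::Result<Linker<Host>> {
    let mut linker = Linker::<Host>::new(engine);
    // The closure hands the generated bindings a `WasiNnView` that borrows
    // both the resource table and the wasi-nn context, mirroring
    // `WasiNnView::new(&mut c.table, &mut c.wasi_nn)` in the test harness.
    wit::add_to_linker(&mut linker, |host: &mut Host| {
        WasiNnView::new(&mut host.table, &mut host.wasi_nn)
    })?;
    Ok(linker)
}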