From 768c77e5916656b0a00df177740155393c50804a Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Tue, 25 Jun 2024 12:04:29 -0700 Subject: [PATCH 1/5] wasi-nn: use resources Recent discussion in the wasi-nn proposal (see [wasi-nn#59], e.g.) has concluded that the right approach for representing wasi-nn "things" (tensors, graph, etc.) is with a component model _resource_. This sweeping change brings Wasmtime's implementation in line with that decision. Initially I had structured this PR to remove all of the WITX-based implementation (#8530). But, after consulting in a Zulip [thread] on what other WASI proposals aim to do, this PR pivoted to support _both_` the WITX-based and WIT-based ABIs (e.g., preview1 era versus preview2, component model era). What is clear is that the WITX-based specification will remain "frozen in time" while the WIT-based implementation moves forward. What that means for this PR is a "split world" paradigm. In many places, we have to distinguish between the `wit` and `witx` versions of the same thing. This change isn't the end state yet: it's a big step forward towards bringing Wasmtime back in line with the WIT spec but, despite my best efforts, doesn't fully fix all the TODOs left behind over several years of development. I have, however, taken the liberty to refactor and fix various parts as I came across them (e.g., the ONNX backend). I plan to continue working on this in future PRs to figure out a good error paradigm (the current one is too wordy) and device residence. [wasi-nn#59]: https://github.com/WebAssembly/wasi-nn/pull/59 [thread]: https://bytecodealliance.zulipchat.com/#narrow/stream/219900-wasi/topic/wasi-nn's.20preview1.20vs.20preview2.20timeline prtest:full --- Cargo.lock | 50 ++- crates/bench-api/src/lib.rs | 4 +- crates/test-programs/artifacts/build.rs | 5 +- .../src/bin/nn_image_classification_winml.rs | 16 - ...rs => nn_wit_image_classification_onnx.rs} | 14 +- .../nn_wit_image_classification_openvino.rs | 25 ++ ...wit_image_classification_openvino_named.rs | 17 + ...n_wit_image_classification_winml_named.rs} | 9 +- .../bin/nn_witx_image_classification_onnx.rs | 22 ++ ... nn_witx_image_classification_openvino.rs} | 12 +- ...itx_image_classification_openvino_named.rs | 17 + ...n_witx_image_classification_winml_named.rs | 18 + crates/test-programs/src/nn.rs | 180 ++++++++-- crates/wasi-nn/Cargo.toml | 17 +- crates/wasi-nn/src/backend/mod.rs | 31 +- crates/wasi-nn/src/backend/onnx.rs | 338 ++++++++++++++++++ crates/wasi-nn/src/backend/onnxruntime.rs | 149 -------- crates/wasi-nn/src/backend/openvino.rs | 36 +- crates/wasi-nn/src/backend/winml.rs | 146 +++++--- crates/wasi-nn/src/ctx.rs | 146 -------- crates/wasi-nn/src/lib.rs | 51 ++- crates/wasi-nn/src/registry/in_memory.rs | 5 +- crates/wasi-nn/src/registry/mod.rs | 1 + crates/wasi-nn/src/wit.rs | 308 ++++++++++++---- crates/wasi-nn/src/witx.rs | 122 ++++++- crates/wasi-nn/tests/check/mod.rs | 8 +- crates/wasi-nn/tests/exec/mod.rs | 54 +-- crates/wasi-nn/tests/exec/wit.rs | 80 +++++ crates/wasi-nn/tests/exec/witx.rs | 52 +++ .../tests/fixtures/{readme.md => README.md} | 0 crates/wasi-nn/tests/test-programs.rs | 200 +++++++---- crates/wasi-nn/wit/wasi-nn.wit | 57 +-- src/commands/run.rs | 67 ++-- src/commands/serve.rs | 30 +- 34 files changed, 1547 insertions(+), 740 deletions(-) delete mode 100644 crates/test-programs/src/bin/nn_image_classification_winml.rs rename crates/test-programs/src/bin/{nn_image_classification_onnx.rs => nn_wit_image_classification_onnx.rs} (67%) create mode 100644 crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs create mode 100644 crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs rename crates/test-programs/src/bin/{nn_image_classification_named.rs => nn_wit_image_classification_winml_named.rs} (53%) create mode 100644 crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs rename crates/test-programs/src/bin/{nn_image_classification.rs => nn_witx_image_classification_openvino.rs} (66%) create mode 100644 crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs create mode 100644 crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs create mode 100644 crates/wasi-nn/src/backend/onnx.rs delete mode 100644 crates/wasi-nn/src/backend/onnxruntime.rs delete mode 100644 crates/wasi-nn/src/ctx.rs create mode 100644 crates/wasi-nn/tests/exec/wit.rs create mode 100644 crates/wasi-nn/tests/exec/witx.rs rename crates/wasi-nn/tests/fixtures/{readme.md => README.md} (100%) diff --git a/Cargo.lock b/Cargo.lock index eccee7eb5785..47e695e49d79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -371,15 +371,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "castaway" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a17ed5635fc8536268e5d4de1e22e81ac34419e5f052d4d51f4e01dcc263fcc" -dependencies = [ - "rustversion", -] - [[package]] name = "cc" version = "1.0.83" @@ -486,19 +477,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" -[[package]] -name = "compact_str" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "ryu", - "static_assertions", -] - [[package]] name = "component-fuzz-util" version = "0.0.0" @@ -1908,9 +1886,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] @@ -2028,21 +2006,22 @@ dependencies = [ [[package]] name = "ort" -version = "2.0.0-rc.0" +version = "2.0.0-rc.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8e5caf4eb2ead4bc137c3ff4e347940e3e556ceb11a4180627f04b63d7342dd" +checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" dependencies = [ - "compact_str", + "js-sys", "ort-sys", "thiserror", "tracing", + "web-sys", ] [[package]] name = "ort-sys" -version = "2.0.0-rc.0" +version = "2.0.0-rc.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f48b5623df2187e0db543ecb2032a6a999081086b7ffddd318000c00b23ace46" +checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" dependencies = [ "flate2", "sha2", @@ -3936,9 +3915,10 @@ dependencies = [ "test-programs-artifacts", "thiserror", "tracing", + "tracing-subscriber", "walkdir", - "wasi-common", "wasmtime", + "wasmtime-wasi", "wiggle", "windows", ] @@ -4024,6 +4004,16 @@ dependencies = [ "wast 211.0.1", ] +[[package]] +name = "web-sys" +version = "0.3.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b17e741662c70c8bd24ac5c5b18de314a2c26c32bf8346ee1e6f53de919c283" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.1" diff --git a/crates/bench-api/src/lib.rs b/crates/bench-api/src/lib.rs index 166db3dc2fe4..9b89faddcbda 100644 --- a/crates/bench-api/src/lib.rs +++ b/crates/bench-api/src/lib.rs @@ -418,7 +418,7 @@ struct BenchState { struct HostState { wasi: WasiCtx, #[cfg(feature = "wasi-nn")] - wasi_nn: wasmtime_wasi_nn::WasiNnCtx, + wasi_nn: wasmtime_wasi_nn::witx::WasiNnCtx, } impl BenchState { @@ -509,7 +509,7 @@ impl BenchState { #[cfg(feature = "wasi-nn")] wasi_nn: { let (backends, registry) = wasmtime_wasi_nn::preload(&[])?; - wasmtime_wasi_nn::WasiNnCtx::new(backends, registry) + wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry) }, }; diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index 20209eaf6de6..7fd653f9b72a 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -90,7 +90,10 @@ fn build_and_generate_tests() { } // Generate a component from each test. - if kind == "nn" || target == "dwarf_imported_memory" || target == "dwarf_shared_memory" { + if target == "dwarf_imported_memory" + || target == "dwarf_shared_memory" + || target.starts_with("nn_witx") + { continue; } let adapter = match target.as_str() { diff --git a/crates/test-programs/src/bin/nn_image_classification_winml.rs b/crates/test-programs/src/bin/nn_image_classification_winml.rs deleted file mode 100644 index 0dc7e8843525..000000000000 --- a/crates/test-programs/src/bin/nn_image_classification_winml.rs +++ /dev/null @@ -1,16 +0,0 @@ -use anyhow::{Context, Result}; -use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; - -pub fn main() -> Result<()> { - let graph = GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU) - .build_from_cache("mobilenet")?; - let tensor = fs::read("fixture/kitten.rgb") - .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; - let top_five = &sort_results(&results)[..5]; - println!("found results, sorted top 5: {:?}", top_five); - assert_eq!(top_five[0].class_id(), 284); - Ok(()) -} diff --git a/crates/test-programs/src/bin/nn_image_classification_onnx.rs b/crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs similarity index 67% rename from crates/test-programs/src/bin/nn_image_classification_onnx.rs rename to crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs index abb77d0e7339..5900e664afde 100644 --- a/crates/test-programs/src/bin/nn_image_classification_onnx.rs +++ b/crates/test-programs/src/bin/nn_wit_image_classification_onnx.rs @@ -1,18 +1,20 @@ use anyhow::{Context, Result}; use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; +use test_programs::nn::{sort_results, wit}; pub fn main() -> Result<()> { let model = fs::read("fixture/model.onnx") .context("the model file to be mapped to the fixture directory")?; - let graph = - GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?; + let graph = wit::load( + &[model], + wit::GraphEncoding::Onnx, + wit::ExecutionTarget::Cpu, + )?; let tensor = fs::read("fixture/000000062808.rgb") .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; + let results = wit::classify(graph, ("input", tensor), "output")?; let top_five = &sort_results(&results)[..5]; - // 963 is meat loaf, meatloaf. + // 963 is "meat loaf, meatloaf." // https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963 assert_eq!(top_five[0].class_id(), 963); println!("found results, sorted top 5: {:?}", top_five); diff --git a/crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs b/crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs new file mode 100644 index 000000000000..52bb6eb5967c --- /dev/null +++ b/crates/test-programs/src/bin/nn_wit_image_classification_openvino.rs @@ -0,0 +1,25 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, wit}; + +pub fn main() -> Result<()> { + let xml = fs::read("fixture/model.xml") + .context("the model file to be mapped to the fixture directory")?; + let weights = fs::read("fixture/model.bin") + .context("the weights file to be mapped to the fixture directory")?; + let graph = wit::load( + &[xml, weights], + wit::GraphEncoding::Openvino, + wit::ExecutionTarget::Cpu, + )?; + let tensor = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = wit::classify( + graph, + ("input", tensor), + "MobilenetV2/Predictions/Reshape_1", + )?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs b/crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs new file mode 100644 index 000000000000..482c77043206 --- /dev/null +++ b/crates/test-programs/src/bin/nn_wit_image_classification_openvino_named.rs @@ -0,0 +1,17 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, wit}; + +pub fn main() -> Result<()> { + let graph = wit::load_by_name("fixtures")?; + let tensor: Vec = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = wit::classify( + graph, + ("input", tensor), + "MobilenetV2/Predictions/Reshape_1", + )?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_image_classification_named.rs b/crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs similarity index 53% rename from crates/test-programs/src/bin/nn_image_classification_named.rs rename to crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs index 9b75a5afb4c6..f3840062e207 100644 --- a/crates/test-programs/src/bin/nn_image_classification_named.rs +++ b/crates/test-programs/src/bin/nn_wit_image_classification_winml_named.rs @@ -1,15 +1,14 @@ use anyhow::{Context, Result}; use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; +use test_programs::nn::{sort_results, wit}; pub fn main() -> Result<()> { - let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) - .build_from_cache("fixtures")?; + let graph = wit::load_by_name("mobilenet")?; let tensor = fs::read("fixture/tensor.bgr") .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; + let results = wit::classify(graph, ("input", tensor), "output")?; let top_five = &sort_results(&results)[..5]; println!("found results, sorted top 5: {:?}", top_five); + assert_eq!(top_five[0].class_id(), 284); Ok(()) } diff --git a/crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs b/crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs new file mode 100644 index 000000000000..fc4991ca44ff --- /dev/null +++ b/crates/test-programs/src/bin/nn_witx_image_classification_onnx.rs @@ -0,0 +1,22 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, witx}; + +pub fn main() -> Result<()> { + let model = fs::read("fixture/model.onnx") + .context("the model file to be mapped to the fixture directory")?; + let graph = witx::load( + &[&model], + witx::GraphEncoding::Onnx, + witx::ExecutionTarget::CPU, + )?; + let tensor = fs::read("fixture/000000062808.rgb") + .context("the tensor file to be mapped to the fixture directory")?; + let results = witx::classify(graph, tensor)?; + let top_five = &sort_results(&results)[..5]; + // 963 is "meat loaf, meatloaf." + // https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963 + assert_eq!(top_five[0].class_id(), 963); + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_image_classification.rs b/crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs similarity index 66% rename from crates/test-programs/src/bin/nn_image_classification.rs rename to crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs index 5815503c3f76..5d557e3b8274 100644 --- a/crates/test-programs/src/bin/nn_image_classification.rs +++ b/crates/test-programs/src/bin/nn_witx_image_classification_openvino.rs @@ -1,18 +1,20 @@ use anyhow::{Context, Result}; use std::fs; -use test_programs::nn::{classify, sort_results}; -use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding}; +use test_programs::nn::{sort_results, witx}; pub fn main() -> Result<()> { let xml = fs::read("fixture/model.xml") .context("the model file to be mapped to the fixture directory")?; let weights = fs::read("fixture/model.bin") .context("the weights file to be mapped to the fixture directory")?; - let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) - .build_from_bytes([&xml, &weights])?; + let graph = witx::load( + &[&xml, &weights], + witx::GraphEncoding::Openvino, + witx::ExecutionTarget::CPU, + )?; let tensor = fs::read("fixture/tensor.bgr") .context("the tensor file to be mapped to the fixture directory")?; - let results = classify(graph, tensor)?; + let results = witx::classify(graph, tensor)?; let top_five = &sort_results(&results)[..5]; println!("found results, sorted top 5: {:?}", top_five); Ok(()) diff --git a/crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs b/crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs new file mode 100644 index 000000000000..d91c78a5c7b5 --- /dev/null +++ b/crates/test-programs/src/bin/nn_witx_image_classification_openvino_named.rs @@ -0,0 +1,17 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, witx}; + +pub fn main() -> Result<()> { + let graph = witx::load_by_name( + "fixtures", + witx::GraphEncoding::Openvino, + witx::ExecutionTarget::CPU, + )?; + let tensor: Vec = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = witx::classify(graph, tensor)?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + Ok(()) +} diff --git a/crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs b/crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs new file mode 100644 index 000000000000..87808f32c2c3 --- /dev/null +++ b/crates/test-programs/src/bin/nn_witx_image_classification_winml_named.rs @@ -0,0 +1,18 @@ +use anyhow::{Context, Result}; +use std::fs; +use test_programs::nn::{sort_results, witx}; + +pub fn main() -> Result<()> { + let graph = witx::load_by_name( + "mobilenet", + witx::GraphEncoding::Onnx, + witx::ExecutionTarget::CPU, + )?; + let tensor = fs::read("fixture/tensor.bgr") + .context("the tensor file to be mapped to the fixture directory")?; + let results = witx::classify(graph, tensor)?; + let top_five = &sort_results(&results)[..5]; + println!("found results, sorted top 5: {:?}", top_five); + assert_eq!(top_five[0].class_id(), 284); + Ok(()) +} diff --git a/crates/test-programs/src/nn.rs b/crates/test-programs/src/nn.rs index f3b54b460901..361a7c1f6282 100644 --- a/crates/test-programs/src/nn.rs +++ b/crates/test-programs/src/nn.rs @@ -1,39 +1,147 @@ -use anyhow::Result; -use std::time::Instant; -use wasi_nn::{Graph, TensorType}; - -/// Run a wasi-nn inference using a simple classifier model (single input, -/// single output). -pub fn classify(graph: Graph, tensor: Vec) -> Result> { - let mut context = graph.init_execution_context()?; - println!( - "[nn] created wasi-nn execution context with ID: {}", - context - ); - - // Many classifiers have a single input; currently, this test suite also - // uses tensors of the same shape, though this is not usually the case. - context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?; - println!("[nn] set input tensor: {} bytes", tensor.len()); - - let before = Instant::now(); - context.compute()?; - println!( - "[nn] executed graph inference in {} ms", - before.elapsed().as_millis() - ); - - // Many classifiers emit probabilities as floating point values; here we - // convert the raw bytes to `f32` knowing all models used here use that - // type. - let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::()]; - let num_bytes = context.get_output(0, &mut output_buffer)?; - println!("[nn] retrieved output tensor: {} bytes", num_bytes); - let output: Vec = output_buffer[..num_bytes] - .chunks(4) - .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) - .collect(); - Ok(output) +//! This module attempts to paper over the differences between the two +//! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and +//! the up-to-date WIT version (`mod wit`). Since the tests are mainly a simple +//! classifier, this exposes a high-level `classify` function to go along with +//! `load`, etc. +//! +//! This module exists solely for convenience--e.g., reduces test duplication. +//! In the future can be safely disposed of or altered as more tests are added. + +/// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the +/// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests. +pub mod wit { + use anyhow::{anyhow, Result}; + use std::time::Instant; + + // Generate the wasi-nn bindings based on the `*.wit` files. + wit_bindgen::generate!({ + path: "../wasi-nn/wit", + world: "ml", + default_bindings_module: "test_programs::ml" + }); + use self::wasi::nn::errors; + use self::wasi::nn::graph::{self, Graph}; + pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests. + use self::wasi::nn::tensor::{Tensor, TensorType}; + + /// Load a wasi-nn graph from a set of bytes. + pub fn load( + bytes: &[Vec], + encoding: GraphEncoding, + target: ExecutionTarget, + ) -> Result { + graph::load(bytes, encoding, target).map_err(err_as_anyhow) + } + + /// Load a wasi-nn graph by name. + pub fn load_by_name(name: &str) -> Result { + graph::load_by_name(name).map_err(err_as_anyhow) + } + + /// Run a wasi-nn inference using a simple classifier model (single input, + /// single output). + pub fn classify(graph: Graph, input: (&str, Vec), output: &str) -> Result> { + let context = graph.init_execution_context().map_err(err_as_anyhow)?; + println!( + "[nn] created wasi-nn execution context with ID: {:?}", + context + ); + + // Many classifiers have a single input; currently, this test suite also + // uses tensors of the same shape, though this is not usually the case. + let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1); + context.set_input(input.0, tensor).map_err(err_as_anyhow)?; + println!("[nn] set input tensor: {} bytes", input.1.len()); + + let before = Instant::now(); + context.compute().map_err(err_as_anyhow)?; + println!( + "[nn] executed graph inference in {} ms", + before.elapsed().as_millis() + ); + + // Many classifiers emit probabilities as floating point values; here we + // convert the raw bytes to `f32` knowing all models used here use that + // type. + let output = context.get_output(output).map_err(err_as_anyhow)?; + println!( + "[nn] retrieved output tensor: {} bytes", + output.data().len() + ); + let output: Vec = output + .data() + .chunks(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(); + Ok(output) + } + + fn err_as_anyhow(e: errors::Error) -> anyhow::Error { + anyhow!("error: {e:?}") + } +} + +/// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based +/// tooling. This older API has been deprecated for the newer WIT-based API but +/// retained for backwards compatibility testing--i.e., `bin/nn_witx_*.rs` +/// tests. +pub mod witx { + use anyhow::Result; + use std::time::Instant; + pub use wasi_nn::{ExecutionTarget, GraphEncoding}; + use wasi_nn::{Graph, GraphBuilder, TensorType}; + + /// Load a wasi-nn graph from a set of bytes. + pub fn load( + bytes: &[&[u8]], + encoding: GraphEncoding, + target: ExecutionTarget, + ) -> Result { + Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?) + } + + /// Load a wasi-nn graph by name. + pub fn load_by_name( + name: &str, + encoding: GraphEncoding, + target: ExecutionTarget, + ) -> Result { + Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?) + } + + /// Run a wasi-nn inference using a simple classifier model (single input, + /// single output). + pub fn classify(graph: Graph, tensor: Vec) -> Result> { + let mut context = graph.init_execution_context()?; + println!( + "[nn] created wasi-nn execution context with ID: {}", + context + ); + + // Many classifiers have a single input; currently, this test suite also + // uses tensors of the same shape, though this is not usually the case. + context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?; + println!("[nn] set input tensor: {} bytes", tensor.len()); + + let before = Instant::now(); + context.compute()?; + println!( + "[nn] executed graph inference in {} ms", + before.elapsed().as_millis() + ); + + // Many classifiers emit probabilities as floating point values; here we + // convert the raw bytes to `f32` knowing all models used here use that + // type. + let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::()]; + let num_bytes = context.get_output(0, &mut output_buffer)?; + println!("[nn] retrieved output tensor: {} bytes", num_bytes); + let output: Vec = output_buffer[..num_bytes] + .chunks(4) + .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) + .collect(); + Ok(output) + } } /// Sort some classification probabilities. diff --git a/crates/wasi-nn/Cargo.toml b/crates/wasi-nn/Cargo.toml index 7390c1e33146..a5ace788d03a 100644 --- a/crates/wasi-nn/Cargo.toml +++ b/crates/wasi-nn/Cargo.toml @@ -20,7 +20,11 @@ anyhow = { workspace = true, features = ['std'] } wiggle = { workspace = true, features = ["wasmtime"] } # This dependency is necessary for the WIT-generation macros to work: -wasmtime = { workspace = true, features = ["component-model", "runtime"] } +wasmtime = { workspace = true, features = [ + "component-model", + "runtime", + "std", +] } # These dependencies are necessary for the wasi-nn implementation: tracing = { workspace = true } @@ -29,7 +33,7 @@ openvino = { version = "0.6.0", features = [ "runtime-linking", ], optional = true } -ort = { version = "2.0.0-rc.0", default-features = false, features = [ +ort = { version = "2.0.0-rc.2", default-features = false, features = [ "copy-dylibs", "download-binaries", ], optional = true } @@ -46,16 +50,17 @@ walkdir = { workspace = true } cap-std = { workspace = true } libtest-mimic = { workspace = true } test-programs-artifacts = { workspace = true } -wasi-common = { workspace = true, features = ["sync"] } +wasmtime-wasi = { workspace = true, features = ["preview1"] } wasmtime = { workspace = true, features = ["cranelift"] } +tracing-subscriber = { workspace = true } [features] default = ["openvino", "winml"] -# openvino is available on all platforms, it requires openvino installed. +# OpenVINO is available on all platforms; it requires OpenVINO to be installed. openvino = ["dep:openvino"] -# onnx is available on all platforms. +# ONNX is available on all platforms. onnx = ["dep:ort"] -# winml is only available on Windows 10 1809 and later. +# WinML is only available on Windows 10 1809 and later. winml = ["dep:windows"] [[test]] diff --git a/crates/wasi-nn/src/backend/mod.rs b/crates/wasi-nn/src/backend/mod.rs index e06f9e364925..710c93b980f4 100644 --- a/crates/wasi-nn/src/backend/mod.rs +++ b/crates/wasi-nn/src/backend/mod.rs @@ -3,20 +3,20 @@ //! implementations to maintain backend-specific state between calls. #[cfg(feature = "onnx")] -pub mod onnxruntime; +pub mod onnx; #[cfg(feature = "openvino")] pub mod openvino; #[cfg(all(feature = "winml", target_os = "windows"))] pub mod winml; #[cfg(feature = "onnx")] -use self::onnxruntime::OnnxBackend; +use self::onnx::OnnxBackend; #[cfg(feature = "openvino")] use self::openvino::OpenvinoBackend; #[cfg(all(feature = "winml", target_os = "windows"))] use self::winml::WinMLBackend; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; +use crate::wit::{ExecutionTarget, GraphEncoding, Tensor}; use crate::{Backend, ExecutionContext, Graph}; use std::fs::File; use std::io::Read; @@ -69,9 +69,30 @@ pub trait BackendGraph: Send + Sync { /// A [BackendExecutionContext] performs the actual inference; this is the /// backing implementation for a user-facing execution context. pub trait BackendExecutionContext: Send + Sync { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError>; + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError>; fn compute(&mut self) -> Result<(), BackendError>; - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result; + fn get_output(&mut self, id: Id) -> Result; +} + +/// An identifier for a tensor in a [Graph]. +#[derive(Debug)] +pub enum Id { + Index(u32), + Name(String), +} +impl Id { + pub fn index(&self) -> Option { + match self { + Id::Index(i) => Some(*i), + Id::Name(_) => None, + } + } + pub fn name(&self) -> Option<&str> { + match self { + Id::Index(_) => None, + Id::Name(n) => Some(n), + } + } } /// Errors returned by a backend; [BackendError::BackendAccess] is a catch-all diff --git a/crates/wasi-nn/src/backend/onnx.rs b/crates/wasi-nn/src/backend/onnx.rs new file mode 100644 index 000000000000..b6ae2ddf278e --- /dev/null +++ b/crates/wasi-nn/src/backend/onnx.rs @@ -0,0 +1,338 @@ +//! Implements a `wasi-nn` [`BackendInner`] using ONNX via the `ort` crate. + +use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; +use crate::backend::{read, Id}; +use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; +use crate::{ExecutionContext, Graph}; +use anyhow::Context; +use ort::{inputs, GraphOptimizationLevel, Session}; +use std::path::Path; +use std::sync::{Arc, Mutex}; + +#[derive(Default)] +pub struct OnnxBackend(); +unsafe impl Send for OnnxBackend {} +unsafe impl Sync for OnnxBackend {} + +impl BackendInner for OnnxBackend { + fn encoding(&self) -> GraphEncoding { + GraphEncoding::Onnx + } + + fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result { + if builders.len() != 1 { + return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into()); + } + + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .commit_from_memory(builders[0])?; + + let box_: Box = + Box::new(OnnxGraph(Arc::new(Mutex::new(session)), target)); + Ok(box_.into()) + } + + fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> { + Some(self) + } +} + +impl BackendFromDir for OnnxBackend { + fn load_from_dir( + &mut self, + path: &Path, + target: ExecutionTarget, + ) -> Result { + let model = read(&path.join("model.onnx"))?; + self.load(&[&model], target) + } +} + +struct OnnxGraph(Arc>, #[allow(dead_code)] ExecutionTarget); +unsafe impl Send for OnnxGraph {} +unsafe impl Sync for OnnxGraph {} + +impl BackendGraph for OnnxGraph { + fn init_execution_context(&self) -> Result { + let session = self.0.lock().unwrap(); + // We need to hold on to the names of the inputs in order for + // `set_input` to work with both indexes and names. Having the + // dimensions and type around is useful for validation but could be + // retrieved from the session. + let mut inputs = vec![]; + for input in &session.inputs { + let shape = Shape::from_onnx_input(input)?; + inputs.push(TensorSlot { + shape, + tensor: None, + }); + } + // We need to keep track of the output shapes since they are used for + // creating the output tensor. + let mut outputs = vec![]; + for output in &session.outputs { + let shape = Shape::from_onnx_output(output)?; + outputs.push(TensorSlot { + shape, + tensor: None, + }); + } + let box_: Box = Box::new(OnnxExecutionContext { + session: self.0.clone(), + inputs, + outputs, + }); + Ok(box_.into()) + } +} + +struct OnnxExecutionContext { + session: Arc>, + inputs: Vec, + outputs: Vec, +} + +unsafe impl Send for OnnxExecutionContext {} +unsafe impl Sync for OnnxExecutionContext {} + +impl OnnxExecutionContext { + /// Helper function for finding the internal index of a tensor by [`Id`]. + fn find(&self, id: Id, list: &[TensorSlot]) -> Result { + let index = match id { + Id::Index(i) => { + let i = i as usize; + if i < list.len() { + i + } else { + return Err(BackendError::BackendAccess(anyhow::anyhow!( + "incorrect tensor index: {i} >= {}", + list.len() + ))); + } + } + Id::Name(n) => list.iter().position(|s| s.shape.name == n).ok_or_else(|| { + BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {n}")) + })?, + }; + Ok(index) + } +} + +impl BackendExecutionContext for OnnxExecutionContext { + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { + let index = self.find(id, &self.inputs)?; + let input = &mut self.inputs[index]; + if let Err(e) = input.shape.matches(tensor) { + return Err(e.into()); + } + // Hold the tensor data on the context until `compute` is called. + input.tensor.replace(tensor.clone()); + Ok(()) + } + + fn compute(&mut self) -> Result<(), BackendError> { + let mut session_inputs: Vec> = vec![]; + for i in &self.inputs { + session_inputs.extend(to_input_value(i)?); + } + let session = self.session.lock().unwrap(); + let session_outputs = session.run(session_inputs.as_slice())?; + for i in 0..self.outputs.len() { + // TODO: fix preexisting gap--this only handles f32 tensors. + let raw: (Vec, &[f32]) = session_outputs[i].try_extract_raw_tensor()?; + let f32s = raw.1.to_vec(); + let output = &mut self.outputs[i]; + output.tensor.replace(Tensor { + dimensions: output.shape.dimensions_as_u32()?, + ty: output.shape.ty, + data: f32_vec_to_bytes(f32s), + }); + } + Ok(()) + } + + fn get_output(&mut self, id: Id) -> Result { + let index = self.find(id, &self.outputs)?; + let output = &self.outputs[index]; + if let Some(tensor) = &output.tensor { + Ok(tensor.clone()) + } else { + Err(BackendError::BackendAccess(anyhow::anyhow!( + "missing output tensor: {}; has `compute` been called?", + output.shape.name + ))) + } + } +} + +impl From for BackendError { + fn from(e: ort::Error) -> Self { + BackendError::BackendAccess(e.into()) + } +} + +/// Holds a slot for ONNX session inputs and outputs. +/// +/// TODO: it seems unfortunate that we have to "hold" some extra data per +/// session but in the input case, this is necessary for name-based indexing. +struct TensorSlot { + shape: Shape, + tensor: Option, +} + +/// Describes a tensor in ONNX terms. +struct Shape { + name: String, + dimensions: Vec, + ty: TensorType, +} + +impl Shape { + fn from_onnx_input(input: &ort::Input) -> Result { + let name = input.name.clone(); + let (dimensions, ty) = convert_value_type(&input.input_type)?; + Ok(Self { + name, + dimensions, + ty, + }) + } + + fn from_onnx_output(output: &ort::Output) -> Result { + let name = output.name.clone(); + let (dimensions, ty) = convert_value_type(&output.output_type)?; + Ok(Self { + name, + dimensions, + ty, + }) + } + + fn dimensions_as_u32(&self) -> Result, BackendError> { + self.dimensions + .iter() + .map(|d| if *d == -1 { Ok(1) } else { convert_i64(d) }) + .collect() + } + + fn matches(&self, tensor: &Tensor) -> anyhow::Result<()> { + if self.dimensions.len() != tensor.dimensions.len() { + return Err(anyhow::anyhow!( + "input tensor cardinality does not match model: {:?} != {:?}", + self.dimensions, + tensor.dimensions + )); + } else { + for (&shape_dim, &tensor_dim) in self.dimensions.iter().zip(tensor.dimensions.iter()) { + let tensor_dim = tensor_dim as i64; + if !is_dynamic_dimension(shape_dim) && shape_dim != tensor_dim { + return Err(anyhow::anyhow!( + "input tensor dimensions do not match model: {:?} != {:?}", + self.dimensions, + tensor.dimensions + )); + } + } + } + if self.ty != tensor.ty { + return Err(anyhow::anyhow!( + "input tensor type does not match model: {:?} != {:?}", + self.ty, + tensor.ty + )); + } + Ok(()) + } +} + +fn convert_value_type(vt: &ort::ValueType) -> Result<(Vec, TensorType), BackendError> { + match vt { + ort::ValueType::Tensor { ty, dimensions } => { + let dims = dimensions.clone(); + let ty = (*ty).try_into()?; + Ok((dims, ty)) + } + _ => Err(BackendError::BackendAccess(anyhow::anyhow!( + "unsupported input type: {vt:?}" + ))), + } +} + +fn convert_i64(i: &i64) -> Result { + u32::try_from(*i).map_err(|d| -> BackendError { + anyhow::anyhow!("unable to convert dimension to u32: {d}").into() + }) +} + +impl TryFrom for TensorType { + type Error = BackendError; + fn try_from(ty: ort::TensorElementType) -> Result { + match ty { + ort::TensorElementType::Float32 => Ok(TensorType::Fp32), + ort::TensorElementType::Float64 => Ok(TensorType::Fp64), + ort::TensorElementType::Uint8 => Ok(TensorType::U8), + ort::TensorElementType::Int32 => Ok(TensorType::I32), + ort::TensorElementType::Int64 => Ok(TensorType::I64), + _ => Err(BackendError::BackendAccess(anyhow::anyhow!( + "unsupported tensor type: {ty:?}" + ))), + } + } +} + +fn to_input_value(slot: &TensorSlot) -> Result<[ort::SessionInputValue<'_>; 1], BackendError> { + match &slot.tensor { + Some(tensor) => match tensor.ty { + TensorType::Fp32 => { + let data = bytes_to_f32_vec(tensor.data.to_vec()); + let dimensions = tensor + .dimensions + .iter() + .map(|d| *d as i64) // TODO: fewer conversions + .collect::>(); + Ok(inputs![(dimensions, Arc::new(data.into_boxed_slice()))] + .context("failed to create ONNX session input")?) + } + _ => { + unimplemented!("{:?} not supported by ONNX", tensor.ty); + } + }, + None => { + return Err(BackendError::BackendAccess(anyhow::anyhow!( + "missing input tensor: {}", + slot.shape.name + ))); + } + } +} + +pub fn f32_vec_to_bytes(data: Vec) -> Vec { + let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect(); + let result: Vec = chunks.iter().flatten().copied().collect(); + result +} + +pub fn bytes_to_f32_vec(data: Vec) -> Vec { + let chunks: Vec<&[u8]> = data.chunks(4).collect(); + let v: Vec = chunks + .into_iter() + .map(|c| f32::from_le_bytes(c.try_into().unwrap())) + .collect(); + + v.into_iter().collect() +} + +/// Returns whether the dimension is dynamic. +/// +/// ONNX uses [dimensional variables] (i.e., name strings) to indicate that the +/// value of a tensor dimension is user-defined, not fixed by the model. This is +/// useful for batching up several inference requests, e.g. When `ort` returns a +/// dimension of this kind, though, it uses `-1` to indicate that the dimension +/// is dynamic. +/// +/// [dimensional variables]: +/// https://onnx.ai/onnx/repo-docs/IR.html#static-tensor-shapes +fn is_dynamic_dimension(d: i64) -> bool { + d == -1 +} diff --git a/crates/wasi-nn/src/backend/onnxruntime.rs b/crates/wasi-nn/src/backend/onnxruntime.rs deleted file mode 100644 index bddb03dc9b35..000000000000 --- a/crates/wasi-nn/src/backend/onnxruntime.rs +++ /dev/null @@ -1,149 +0,0 @@ -//! Implements a `wasi-nn` [`BackendInner`] using ONNX via ort. - -use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; -use crate::backend::read; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; -use crate::{ExecutionContext, Graph}; -use ort::{inputs, GraphOptimizationLevel, Session}; -use std::path::Path; -use std::sync::{Arc, Mutex}; - -#[derive(Default)] -pub struct OnnxBackend(); -unsafe impl Send for OnnxBackend {} -unsafe impl Sync for OnnxBackend {} - -impl BackendInner for OnnxBackend { - fn encoding(&self) -> GraphEncoding { - GraphEncoding::Onnx - } - - fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result { - if builders.len() != 1 { - return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into()); - } - - let session = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_model_from_memory(builders[0])?; - - let box_: Box = - Box::new(ONNXGraph(Arc::new(Mutex::new(session)), target)); - Ok(box_.into()) - } - - fn as_dir_loadable<'a>(&'a mut self) -> Option<&'a mut dyn BackendFromDir> { - Some(self) - } -} - -impl BackendFromDir for OnnxBackend { - fn load_from_dir( - &mut self, - path: &Path, - target: ExecutionTarget, - ) -> Result { - let model = read(&path.join("model.onnx"))?; - self.load(&[&model], target) - } -} - -struct ONNXGraph(Arc>, #[allow(dead_code)] ExecutionTarget); - -unsafe impl Send for ONNXGraph {} -unsafe impl Sync for ONNXGraph {} - -impl BackendGraph for ONNXGraph { - fn init_execution_context(&self) -> Result { - let session = self.0.lock().unwrap(); - let inputs = session.inputs.iter().map(|_| None).collect::>(); - let outputs = session.outputs.iter().map(|_| None).collect::>(); - let box_: Box = Box::new(ONNXExecutionContext { - session: self.0.clone(), - inputs, - outputs, - }); - Ok(box_.into()) - } -} - -struct ONNXExecutionContext { - session: Arc>, - inputs: Vec>, - outputs: Vec>>, -} - -unsafe impl Send for ONNXExecutionContext {} -unsafe impl Sync for ONNXExecutionContext {} - -impl BackendExecutionContext for ONNXExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { - self.inputs[index as usize].replace(tensor.clone()); - Ok(()) - } - - fn compute(&mut self) -> Result<(), BackendError> { - let shaped_inputs: Vec<_> = self - .inputs - .iter() - .enumerate() - .map(|(i, _o)| { - let input = self.inputs[i].as_ref().unwrap(); - let dims = input - .dimensions - .as_slice() - .iter() - .map(|d| *d as i64) - .collect::>(); - match input.tensor_type { - TensorType::Fp32 => { - let data = bytes_to_f32_vec(input.data.to_vec()); - inputs![(dims, Arc::new(data.into_boxed_slice()))].unwrap() - } - _ => { - unimplemented!("{:?} not supported by ONNX", input.tensor_type); - } - } - }) - .flatten() - .collect(); - - let session = self.session.lock().unwrap(); - let res = session.run(shaped_inputs.as_slice())?; - - for i in 0..self.outputs.len() { - let raw: (Vec, &[f32]) = res[i].extract_raw_tensor()?; - let f32s = raw.1.to_vec(); - self.outputs[i].replace(f32_vec_to_bytes(f32s)); - } - Ok(()) - } - - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result { - let output = self.outputs[index as usize].as_ref().unwrap(); - destination[..output.len()].copy_from_slice(output); - Ok(output.len() as u32) - } -} - -impl From for BackendError { - fn from(e: ort::Error) -> Self { - BackendError::BackendAccess(e.into()) - } -} - -pub fn f32_vec_to_bytes(data: Vec) -> Vec { - let chunks: Vec<[u8; 4]> = data.into_iter().map(|f| f.to_le_bytes()).collect(); - let result: Vec = chunks.iter().flatten().copied().collect(); - result -} - -pub fn bytes_to_f32_vec(data: Vec) -> Vec { - let chunks: Vec<&[u8]> = data.chunks(4).collect(); - let v: Vec = chunks - .into_iter() - .map(|c| f32::from_le_bytes(c.try_into().unwrap())) - .collect(); - - v.into_iter().collect() -} diff --git a/crates/wasi-nn/src/backend/openvino.rs b/crates/wasi-nn/src/backend/openvino.rs index 65c96629eead..b24f93838cab 100644 --- a/crates/wasi-nn/src/backend/openvino.rs +++ b/crates/wasi-nn/src/backend/openvino.rs @@ -1,9 +1,9 @@ //! Implements a `wasi-nn` [`BackendInner`] using OpenVINO. use super::{ - read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, + read, BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id, }; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; +use crate::wit::{self, ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; use std::path::Path; @@ -99,12 +99,15 @@ impl BackendGraph for OpenvinoGraph { struct OpenvinoExecutionContext(Arc, openvino::InferRequest); impl BackendExecutionContext for OpenvinoExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { - let input_name = self.0.get_input_name(index as usize)?; + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { + let input_name = match id { + Id::Index(i) => self.0.get_input_name(i as usize)?, + Id::Name(name) => name, + }; // Construct the blob structure. TODO: there must be some good way to // discover the layout here; `desc` should not have to default to NHWC. - let precision = map_tensor_type_to_precision(tensor.tensor_type); + let precision = map_tensor_type_to_precision(tensor.ty); let dimensions = tensor .dimensions .iter() @@ -123,17 +126,20 @@ impl BackendExecutionContext for OpenvinoExecutionContext { Ok(()) } - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result { - let output_name = self.0.get_output_name(index as usize)?; + fn get_output(&mut self, id: Id) -> Result { + let output_name = match id { + Id::Index(i) => self.0.get_output_name(i as usize)?, + Id::Name(name) => name, + }; + let dimensions = vec![]; // TODO: get actual shape + let ty = wit::TensorType::Fp32; // TODO: get actual type. let blob = self.1.get_blob(&output_name)?; - let blob_size = blob.byte_len()?; - if blob_size > destination.len() { - return Err(BackendError::NotEnoughMemory(blob_size)); - } - - // Copy the tensor data into the destination buffer. - destination[..blob_size].copy_from_slice(blob.buffer()?); - Ok(blob_size as u32) + let data = blob.buffer()?.to_vec(); + Ok(Tensor { + dimensions, + ty, + data, + }) } } diff --git a/crates/wasi-nn/src/backend/winml.rs b/crates/wasi-nn/src/backend/winml.rs index e11761f86732..b87510bfc877 100644 --- a/crates/wasi-nn/src/backend/winml.rs +++ b/crates/wasi-nn/src/backend/winml.rs @@ -1,16 +1,27 @@ //! Implements a `wasi-nn` [`BackendInner`] using WinML. - -use super::{BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner}; -use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor}; +//! +//! Note that the [docs.rs] documentation for the `windows` crate does have the +//! right features turned on to read about the functions used; see Microsoft's +//! private documentation instead: [microsoft.github.io/windows-docs-rs]. +//! +//! [docs.rs]: https://docs.rs/windows +//! [microsoft.github.io/windows-docs-rs]: https://microsoft.github.io/windows-docs-rs/doc/windows/AI/MachineLearning + +use crate::backend::{ + BackendError, BackendExecutionContext, BackendFromDir, BackendGraph, BackendInner, Id, +}; +use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType}; use crate::{ExecutionContext, Graph}; use std::{fs::File, io::Read, mem::size_of, path::Path}; use windows::core::{ComInterface, HSTRING}; +use windows::Foundation::Collections::IVectorView; use windows::Storage::Streams::{ DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference, }; use windows::AI::MachineLearning::{ - LearningModel, LearningModelBinding, LearningModelDevice, LearningModelDeviceKind, - LearningModelEvaluationResult, LearningModelSession, TensorFeatureDescriptor, TensorFloat, + ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice, + LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession, + TensorFeatureDescriptor, TensorFloat, }; #[derive(Default)] @@ -94,29 +105,64 @@ impl WinMLExecutionContext { } } +impl WinMLExecutionContext { + /// Helper function for finding the internal index of a tensor by [`Id`]. + fn find( + &self, + id: Id, + list: &IVectorView, + ) -> Result { + let index = match id { + Id::Index(i) => { + if i < list.Size()? { + i + } else { + return Err(BackendError::BackendAccess(anyhow::anyhow!( + "incorrect tensor index: {i} >= {}", + list.Size()? + ))); + } + } + Id::Name(name) => list + .into_iter() + .position(|d| d.Name().unwrap() == name) + .ok_or_else(|| { + BackendError::BackendAccess(anyhow::anyhow!("unknown tensor name: {name}")) + })? as u32, + }; + Ok(index) + } +} + impl BackendExecutionContext for WinMLExecutionContext { - fn set_input(&mut self, index: u32, tensor: &Tensor) -> Result<(), BackendError> { + fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> { + let input_features = self.session.Model()?.InputFeatures()?; + let index = self.find(id, &input_features)?; + let input = input_features.GetAt(index)?; + // TODO: Support other tensor types. Only FP32 is supported right now. - match tensor.tensor_type { + match tensor.ty { crate::wit::types::TensorType::Fp32 => {} _ => unimplemented!(), } - let input = self.session.Model()?.InputFeatures()?.GetAt(index)?; - unsafe { - let data = std::slice::from_raw_parts( + // TODO: this is quite unsafe and probably incorrect--will the slice + // still be around by the time the binding is used?! + let data = unsafe { + std::slice::from_raw_parts( tensor.data.as_ptr() as *const f32, - tensor.data.len() / 4, - ); - - self.binding.Bind( - &input.Name()?, - &TensorFloat::CreateFromArray( - &input.cast::()?.Shape()?, - data, - )?, - )?; - } + tensor.data.len() / size_of::(), + ) + }; + + self.binding.Bind( + &input.Name()?, + &TensorFloat::CreateFromArray( + &input.cast::()?.Shape()?, + data, + )?, + )?; + Ok(()) } @@ -125,33 +171,32 @@ impl BackendExecutionContext for WinMLExecutionContext { Ok(()) } - fn get_output(&mut self, index: u32, destination: &mut [u8]) -> Result { - if self.result.is_none() { + fn get_output(&mut self, id: Id) -> Result { + if let Some(result) = &self.result { + let output_features = self.session.Model()?.OutputFeatures()?; + let index = self.find(id, &output_features)?; + let output = output_features.GetAt(index)?; + // TODO: this only handles FP32! + let tensor = result + .Outputs()? + .Lookup(&output.Name()?)? + .cast::()?; + let dimensions = dimensions_as_u32(&tensor.Shape()?)?; + let view = tensor.GetAsVectorView()?; + let mut data = Vec::with_capacity(view.Size()? as usize * size_of::()); + for f in view.into_iter() { + data.extend(f.to_le_bytes()); + } + Ok(Tensor { + ty: TensorType::Fp32, + dimensions, + data, + }) + } else { return Err(BackendError::BackendAccess(anyhow::Error::msg( "Output is not ready.", ))); } - let output_name = self.session.Model()?.OutputFeatures()?.GetAt(index)?; - let output_name_hstring = output_name.Name()?; - - let vector_view = self - .result - .as_ref() - .unwrap() - .Outputs()? - .Lookup(&output_name_hstring)? - .cast::()? - .GetAsVectorView()?; - let output: Vec = vector_view.into_iter().collect(); - let len_to_copy = output.len() * size_of::(); - unsafe { - destination[..len_to_copy].copy_from_slice(std::slice::from_raw_parts( - output.as_ptr() as *const u8, - len_to_copy, - )); - } - - Ok(len_to_copy as u32) } } @@ -168,3 +213,16 @@ impl From for BackendError { BackendError::BackendAccess(anyhow::Error::new(e)) } } + +fn dimensions_as_u32(dimensions: &IVectorView) -> Result, BackendError> { + dimensions + .into_iter() + .map(|d| if d == -1 { Ok(1) } else { convert_i64(d) }) + .collect() +} + +fn convert_i64(i: i64) -> Result { + u32::try_from(i).map_err(|d| -> BackendError { + anyhow::anyhow!("unable to convert dimension to u32: {d}").into() + }) +} diff --git a/crates/wasi-nn/src/ctx.rs b/crates/wasi-nn/src/ctx.rs deleted file mode 100644 index 40cfb9d53d2d..000000000000 --- a/crates/wasi-nn/src/ctx.rs +++ /dev/null @@ -1,146 +0,0 @@ -//! Implements the host state for the `wasi-nn` API: [WasiNnCtx]. - -use crate::backend::{self, BackendError}; -use crate::wit::types::GraphEncoding; -use crate::{Backend, ExecutionContext, Graph, InMemoryRegistry, Registry}; -use anyhow::anyhow; -use std::{collections::HashMap, hash::Hash, path::Path}; -use thiserror::Error; -use wiggle::GuestError; - -type GraphId = u32; -type GraphExecutionContextId = u32; -type BackendName = String; -type GraphDirectory = String; - -/// Construct an in-memory registry from the available backends and a list of -/// `(, )`. This assumes graphs can be loaded -/// from a local directory, which is a safe assumption currently for the current -/// model types. -pub fn preload( - preload_graphs: &[(BackendName, GraphDirectory)], -) -> anyhow::Result<(impl IntoIterator, Registry)> { - let mut backends = backend::list(); - let mut registry = InMemoryRegistry::new(); - for (kind, path) in preload_graphs { - let kind_ = kind.parse()?; - let backend = backends - .iter_mut() - .find(|b| b.encoding() == kind_) - .ok_or(anyhow!("unsupported backend: {}", kind))? - .as_dir_loadable() - .ok_or(anyhow!("{} does not support directory loading", kind))?; - registry.load(backend, Path::new(path))?; - } - Ok((backends, Registry::from(registry))) -} - -/// Capture the state necessary for calling into the backend ML libraries. -pub struct WasiNnCtx { - pub(crate) backends: HashMap, - pub(crate) registry: Registry, - pub(crate) graphs: Table, - pub(crate) executions: Table, -} - -impl WasiNnCtx { - /// Make a new context from the default state. - pub fn new(backends: impl IntoIterator, registry: Registry) -> Self { - let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect(); - Self { - backends, - registry, - graphs: Table::default(), - executions: Table::default(), - } - } -} - -/// Possible errors while interacting with [WasiNnCtx]. -#[derive(Debug, Error)] -pub enum WasiNnError { - #[error("backend error")] - BackendError(#[from] BackendError), - #[error("guest error")] - GuestError(#[from] GuestError), - #[error("usage error")] - UsageError(#[from] UsageError), -} - -#[derive(Debug, Error)] -pub enum UsageError { - #[error("Invalid context; has the load function been called?")] - InvalidContext, - #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] - InvalidEncoding(GraphEncoding), - #[error("OpenVINO expects only two buffers (i.e. [ir, weights]), passed: {0}")] - InvalidNumberOfBuilders(u32), - #[error("Invalid graph handle; has it been loaded?")] - InvalidGraphHandle, - #[error("Invalid execution context handle; has it been initialized?")] - InvalidExecutionContextHandle, - #[error("Not enough memory to copy tensor data of size: {0}")] - NotEnoughMemory(u32), - #[error("No graph found with name: {0}")] - NotFound(String), -} - -pub(crate) type WasiNnResult = std::result::Result; - -/// Record handle entries in a table. -pub struct Table { - entries: HashMap, - next_key: u32, -} - -impl Default for Table { - fn default() -> Self { - Self { - entries: HashMap::new(), - next_key: 0, - } - } -} - -impl Table -where - K: Eq + Hash + From + Copy, -{ - pub fn insert(&mut self, value: V) -> K { - let key = self.use_next_key(); - self.entries.insert(key, value); - key - } - - pub fn get(&self, key: K) -> Option<&V> { - self.entries.get(&key) - } - - pub fn get_mut(&mut self, key: K) -> Option<&mut V> { - self.entries.get_mut(&key) - } - - fn use_next_key(&mut self) -> K { - let current = self.next_key; - self.next_key += 1; - K::from(current) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::registry::GraphRegistry; - - #[test] - fn example() { - struct FakeRegistry; - impl GraphRegistry for FakeRegistry { - fn get_mut(&mut self, _: &str) -> Option<&mut Graph> { - None - } - } - - let _ctx = WasiNnCtx::new([], Registry::from(FakeRegistry)); - } -} diff --git a/crates/wasi-nn/src/lib.rs b/crates/wasi-nn/src/lib.rs index 71d089d07489..a9f86f2da5a2 100644 --- a/crates/wasi-nn/src/lib.rs +++ b/crates/wasi-nn/src/lib.rs @@ -1,14 +1,34 @@ -mod ctx; -mod registry; - pub mod backend; -pub use ctx::{preload, WasiNnCtx}; -pub use registry::{GraphRegistry, InMemoryRegistry}; +mod registry; pub mod wit; pub mod witx; +use anyhow::anyhow; +use core::fmt; +pub use registry::{GraphRegistry, InMemoryRegistry}; +use std::path::Path; use std::sync::Arc; +/// Construct an in-memory registry from the available backends and a list of +/// `(, )`. This assumes graphs can be loaded +/// from a local directory, which is a safe assumption currently for the current +/// model types. +pub fn preload(preload_graphs: &[(String, String)]) -> anyhow::Result<(Vec, Registry)> { + let mut backends = backend::list(); + let mut registry = InMemoryRegistry::new(); + for (kind, path) in preload_graphs { + let kind_ = kind.parse()?; + let backend = backends + .iter_mut() + .find(|b| b.encoding() == kind_) + .ok_or(anyhow!("unsupported backend: {}", kind))? + .as_dir_loadable() + .ok_or(anyhow!("{} does not support directory loading", kind))?; + registry.load(backend, Path::new(path))?; + } + Ok((backends, Registry::from(registry))) +} + /// A machine learning backend. pub struct Backend(Box); impl std::ops::Deref for Backend { @@ -43,6 +63,27 @@ impl std::ops::Deref for Graph { } } +/// A host-side tensor. +/// +/// Eventually, this may be defined in each backend as they gain the ability to +/// hold tensors on various devices (TODO: +/// https://github.com/WebAssembly/wasi-nn/pull/70). +#[derive(Clone)] +pub struct Tensor { + dimensions: Vec, + ty: wit::TensorType, + data: Vec, +} +impl fmt::Debug for Tensor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Tensor") + .field("dimensions", &self.dimensions) + .field("ty", &self.ty) + .field("data (bytes)", &self.data.len()) + .finish() + } +} + /// A backend-defined execution context. pub struct ExecutionContext(Box); impl From> for ExecutionContext { diff --git a/crates/wasi-nn/src/registry/in_memory.rs b/crates/wasi-nn/src/registry/in_memory.rs index b008f7f43684..86f81203d646 100644 --- a/crates/wasi-nn/src/registry/in_memory.rs +++ b/crates/wasi-nn/src/registry/in_memory.rs @@ -2,7 +2,7 @@ use super::{Graph, GraphRegistry}; use crate::backend::BackendFromDir; -use crate::wit::types::ExecutionTarget; +use crate::wit::ExecutionTarget; use anyhow::{anyhow, bail}; use std::{collections::HashMap, path::Path}; @@ -37,6 +37,9 @@ impl InMemoryRegistry { } impl GraphRegistry for InMemoryRegistry { + fn get(&self, name: &str) -> Option<&Graph> { + self.0.get(name) + } fn get_mut(&mut self, name: &str) -> Option<&mut Graph> { self.0.get_mut(name) } diff --git a/crates/wasi-nn/src/registry/mod.rs b/crates/wasi-nn/src/registry/mod.rs index 83f88e4dca0e..5f4d959132dc 100644 --- a/crates/wasi-nn/src/registry/mod.rs +++ b/crates/wasi-nn/src/registry/mod.rs @@ -12,5 +12,6 @@ use crate::Graph; pub use in_memory::InMemoryRegistry; pub trait GraphRegistry: Send + Sync { + fn get(&self, name: &str) -> Option<&Graph>; fn get_mut(&mut self, name: &str) -> Option<&mut Graph>; } diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index dbe894357cdf..2f5ff2f7c53e 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -15,8 +15,35 @@ //! [`Backend`]: crate::Backend //! [`types`]: crate::wit::types -use crate::{ctx::UsageError, WasiNnCtx}; -use std::{error::Error, fmt, hash::Hash, str::FromStr}; +use crate::backend::Id; +use crate::{Backend, Registry}; +use std::collections::HashMap; +use std::hash::Hash; +use std::{fmt, str::FromStr}; +use wasmtime::component::{Resource, ResourceTable}; + +/// Capture the state necessary for calling into the backend ML libraries. +pub struct WasiNnCtx { + pub(crate) backends: HashMap, + pub(crate) registry: Registry, +} + +impl WasiNnCtx { + /// Make a new context from the default state. + pub fn new(backends: impl IntoIterator, registry: Registry) -> Self { + let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect(); + Self { backends, registry } + } +} + +/// A trait for modifying internal wasi-nn state. +/// +/// This follows the pattern used by other WASI proposals (see `wasmtime-wasi`, +/// `wasmtime-wasi-http`). +pub trait WasiNnView { + fn ctx(&mut self) -> &mut WasiNnCtx; + fn table(&mut self) -> &mut ResourceTable; +} /// Generate the traits and types from the `wasi-nn` WIT specification. mod gen_ { @@ -24,126 +51,259 @@ mod gen_ { world: "ml", path: "wit/wasi-nn.wit", trappable_imports: true, + with: { + // Configure all WIT http resources to be defined types in this + // crate to use the `ResourceTable` helper methods. + "wasi:nn/graph/graph": crate::Graph, + "wasi:nn/tensor/tensor": crate::Tensor, + "wasi:nn/inference/graph-execution-context": crate::ExecutionContext, + }, }); } -use gen_::wasi::nn as gen; // Shortcut to the module containing the types we need. +use gen_::wasi::nn::{self as gen}; // Shortcut to the module containing the types we need. // Export the `types` used in this crate as well as `ML::add_to_linker`. pub mod types { use super::gen; - pub use gen::graph::{ExecutionTarget, Graph, GraphEncoding}; + pub use gen::errors::Error; + pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding}; pub use gen::inference::GraphExecutionContext; pub use gen::tensor::{Tensor, TensorType}; } +pub use gen::errors::Error; +pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding}; +pub use gen::inference::GraphExecutionContext; +pub use gen::tensor::{Tensor, TensorData, TensorDimensions, TensorType}; pub use gen_::Ml as ML; -impl gen::graph::Host for WasiNnCtx { - /// Load an opaque sequence of bytes to use for inference. +pub fn add_to_linker(l: &mut wasmtime::component::Linker) -> anyhow::Result<()> +where + T: WasiNnView, +{ + let closure = type_annotate_nn::(|t| WasiNnImpl(t)); + gen::graph::add_to_linker_get_host(l, closure)?; + gen::tensor::add_to_linker_get_host(l, closure)?; + gen::inference::add_to_linker_get_host(l, closure)?; + gen::errors::add_to_linker_get_host(l, closure)?; + Ok(()) +} + +pub struct WasiNnImpl(pub T); + +fn type_annotate_nn(val: F) -> F +where + F: Fn(&mut T) -> WasiNnImpl<&mut T>, +{ + val +} + +impl WasiNnView for WasiNnImpl<&mut T> { + fn ctx(&mut self) -> &mut WasiNnCtx { + self.0.ctx() + } + fn table(&mut self) -> &mut wasmtime::component::ResourceTable { + self.0.table() + } +} + +impl gen::graph::Host for WasiNnImpl<&mut T> { fn load( &mut self, - builders: Vec, - encoding: gen::graph::GraphEncoding, - target: gen::graph::ExecutionTarget, - ) -> wasmtime::Result> { - let graph = if let Some(backend) = self.backends.get_mut(&encoding) { + builders: Vec, + encoding: GraphEncoding, + target: ExecutionTarget, + ) -> wasmtime::Result, Error>> { + use core::result::Result::*; + tracing::debug!("load {encoding:?} {target:?}"); + let result = if let Some(backend) = self.ctx().backends.get_mut(&encoding) { let slices = builders.iter().map(|s| s.as_slice()).collect::>(); - backend.load(&slices, target.into())? + match backend.load(&slices, target.into()) { + Ok(graph) => { + let graph = self.0.table().push(graph)?; + Ok(graph) + } + Err(error) => { + tracing::error!("failed to load graph: {error:?}"); + Err(Error::RuntimeError) + } + } } else { - return Err(UsageError::InvalidEncoding(encoding.into()).into()); + Err(Error::InvalidEncoding) }; - let graph_id = self.graphs.insert(graph); - Ok(Ok(graph_id)) + wasmtime::Result::Ok(result) } fn load_by_name( &mut self, name: String, - ) -> wasmtime::Result> { - if let Some(graph) = self.registry.get_mut(&name) { - let graph_id = self.graphs.insert(graph.clone().into()); - Ok(Ok(graph_id)) + ) -> wasmtime::Result, Error>> { + use core::result::Result::*; + tracing::debug!("load by name {name:?}"); + let registry = &self.ctx().registry; + let result = if let Some(graph) = registry.get(&name) { + let graph = graph.clone(); + let graph = self.0.table().push(graph)?; + Ok(graph) } else { - return Err(UsageError::NotFound(name.to_string()).into()); - } + tracing::error!("failed to find graph with name: {name}"); + Err(Error::NotFound) + }; + wasmtime::Result::Ok(result) } } -impl gen::inference::Host for WasiNnCtx { - /// Create an execution instance of a loaded graph. - /// - /// TODO: remove completely? +impl gen::graph::HostGraph for WasiNnImpl<&mut T> { fn init_execution_context( &mut self, - graph_id: gen::graph::Graph, - ) -> wasmtime::Result> { - let exec_context = if let Some(graph) = self.graphs.get(graph_id) { - graph.init_execution_context()? - } else { - return Err(UsageError::InvalidGraphHandle.into()); + graph: Resource, + ) -> wasmtime::Result, Error>> { + use core::result::Result::*; + tracing::debug!("initialize execution context"); + let graph = self.0.table().get(&graph)?; + let result = match graph.init_execution_context() { + Ok(exec_context) => { + let exec_context = self.0.table().push(exec_context)?; + Ok(exec_context) + } + Err(error) => { + tracing::error!("failed to initialize execution context: {error:?}"); + Err(Error::RuntimeError) + } }; + wasmtime::Result::Ok(result) + } - let exec_context_id = self.executions.insert(exec_context); - Ok(Ok(exec_context_id)) + fn drop(&mut self, graph: Resource) -> wasmtime::Result<()> { + self.0.table().delete(graph)?; + Ok(()) } +} - /// Define the inputs to use for inference. +impl gen::inference::HostGraphExecutionContext for WasiNnImpl<&mut T> { fn set_input( &mut self, - exec_context_id: gen::inference::GraphExecutionContext, - index: u32, - tensor: gen::tensor::Tensor, - ) -> wasmtime::Result> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - exec_context.set_input(index, &tensor)?; - Ok(Ok(())) + exec_context: Resource, + name: String, + tensor: Resource, + ) -> wasmtime::Result> { + use core::result::Result::*; + let tensor = self.0.table().get(&tensor)?; + tracing::debug!("set input {name:?}: {tensor:?}"); + let tensor = tensor.clone(); // TODO: avoid copying the tensor + let exec_context = self.0.table().get_mut(&exec_context)?; + let result = if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) { + tracing::error!("failed to set input: {e:?}"); + Err(Error::InvalidArgument) } else { - Err(UsageError::InvalidGraphHandle.into()) - } + Ok(()) + }; + wasmtime::Result::Ok(result) } - /// Compute the inference on the given inputs. - /// - /// TODO: refactor to compute(list) -> result, error> fn compute( &mut self, - exec_context_id: gen::inference::GraphExecutionContext, - ) -> wasmtime::Result> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - exec_context.compute()?; - Ok(Ok(())) - } else { - Err(UsageError::InvalidExecutionContextHandle.into()) - } + exec_context: Resource, + ) -> wasmtime::Result> { + use core::result::Result::*; + let exec_context = self.0.table().get_mut(&exec_context)?; + tracing::debug!("compute"); + let result = match exec_context.compute() { + Ok(()) => Ok(()), + Err(error) => { + tracing::error!("failed to compute: {error:?}"); + Err(Error::RuntimeError) + } + }; + wasmtime::Result::Ok(result) } - /// Extract the outputs after inference. + #[doc = r" Extract the outputs after inference."] fn get_output( &mut self, - exec_context_id: gen::inference::GraphExecutionContext, - index: u32, - ) -> wasmtime::Result> { - if let Some(exec_context) = self.executions.get_mut(exec_context_id) { - // Read the output bytes. TODO: this involves a hard-coded upper - // limit on the tensor size that is necessary because there is no - // way to introspect the graph outputs - // (https://github.com/WebAssembly/wasi-nn/issues/37). - let mut destination = vec![0; 1024 * 1024]; - let bytes_read = exec_context.get_output(index, &mut destination)?; - destination.truncate(bytes_read as usize); - Ok(Ok(destination)) - } else { - Err(UsageError::InvalidGraphHandle.into()) - } + exec_context: Resource, + name: String, + ) -> wasmtime::Result, Error>> { + use core::result::Result::*; + let exec_context = self.0.table().get_mut(&exec_context)?; + tracing::debug!("get output {name:?}"); + let result = match exec_context.get_output(Id::Name(name)) { + Ok(tensor) => { + let tensor = self.0.table().push(tensor)?; + Ok(tensor) + } + Err(error) => { + tracing::error!("failed to get output: {error:?}"); + Err(Error::RuntimeError) + } + }; + wasmtime::Result::Ok(result) + } + + fn drop(&mut self, exec_context: Resource) -> wasmtime::Result<()> { + self.0.table().delete(exec_context)?; + Ok(()) } } -impl gen::errors::Host for WasiNnCtx {} +impl gen::tensor::HostTensor for WasiNnImpl<&mut T> { + fn new( + &mut self, + dimensions: TensorDimensions, + ty: TensorType, + data: TensorData, + ) -> wasmtime::Result> { + let tensor = Tensor { + dimensions, + ty, + data, + }; + let tensor = self.0.table().push(tensor)?; + Ok(tensor) + } -impl gen::tensor::Host for WasiNnCtx {} + fn dimensions(&mut self, tensor: Resource) -> wasmtime::Result { + let tensor = self.0.table().get(&tensor)?; + Ok(tensor.dimensions.clone()) + } + + fn ty(&mut self, tensor: Resource) -> wasmtime::Result { + let tensor = self.0.table().get(&tensor)?; + Ok(tensor.ty) + } + + fn data(&mut self, tensor: Resource) -> wasmtime::Result { + let tensor = self.0.table().get(&tensor)?; + Ok(tensor.data.clone()) + } + + fn drop(&mut self, tensor: Resource) -> wasmtime::Result<()> { + self.0.table().delete(tensor)?; + Ok(()) + } +} + +impl gen::tensor::Host for WasiNnImpl<&mut T> {} +impl gen::errors::Host for WasiNnImpl<&mut T> {} +impl gen::inference::Host for WasiNnImpl<&mut T> {} impl Hash for gen::graph::GraphEncoding { fn hash(&self, state: &mut H) { - core::mem::discriminant(self).hash(state); + self.to_string().hash(state) + } +} + +impl fmt::Display for gen::graph::GraphEncoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use gen::graph::GraphEncoding::*; + match self { + Openvino => write!(f, "openvino"), + Onnx => write!(f, "onnx"), + Pytorch => write!(f, "pytorch"), + Tensorflow => write!(f, "tensorflow"), + Tensorflowlite => write!(f, "tensorflowlite"), + Autodetect => write!(f, "autodetect"), + Ggml => write!(f, "ggml"), + } } } @@ -168,4 +328,4 @@ impl fmt::Display for GraphEncodingParseError { write!(f, "unknown graph encoding: {}", self.0) } } -impl Error for GraphEncodingParseError {} +impl std::error::Error for GraphEncodingParseError {} diff --git a/crates/wasi-nn/src/witx.rs b/crates/wasi-nn/src/witx.rs index 9d78a595ec90..f4f2eab90647 100644 --- a/crates/wasi-nn/src/witx.rs +++ b/crates/wasi-nn/src/witx.rs @@ -13,11 +13,83 @@ //! //! [`types`]: crate::wit::types -use crate::ctx::{UsageError, WasiNnCtx, WasiNnError, WasiNnResult as Result}; -use wiggle::{GuestMemory, GuestPtr}; +use crate::backend::BackendError; +use crate::backend::Id; +use crate::wit::GraphEncoding; +use crate::{Backend, ExecutionContext, Graph, Registry}; +use std::collections::HashMap; +use std::hash::Hash; +use thiserror::Error; +use wiggle::{GuestError, GuestMemory, GuestPtr}; pub use gen::wasi_ephemeral_nn::add_to_linker; +pub(crate) type WasiNnResult = std::result::Result; +type Result = WasiNnResult; +type GraphId = u32; +type GraphExecutionContextId = u32; + +/// Capture the state necessary for calling into the backend ML libraries. +pub struct WasiNnCtx { + pub(crate) backends: HashMap, + pub(crate) registry: Registry, + pub(crate) graphs: Table, + pub(crate) executions: Table, +} + +impl WasiNnCtx { + /// Make a new context from the default state. + pub fn new(backends: impl IntoIterator, registry: Registry) -> Self { + let backends = backends.into_iter().map(|b| (b.encoding(), b)).collect(); + Self { + backends, + registry, + graphs: Table::default(), + executions: Table::default(), + } + } +} + +/// Record handle entries in a table. +pub struct Table { + entries: HashMap, + next_key: u32, +} + +impl Default for Table { + fn default() -> Self { + Self { + entries: HashMap::new(), + next_key: 0, + } + } +} + +impl Table +where + K: Eq + Hash + From + Copy, +{ + pub fn insert(&mut self, value: V) -> K { + let key = self.use_next_key(); + self.entries.insert(key, value); + key + } + + pub fn get(&self, key: K) -> Option<&V> { + self.entries.get(&key) + } + + pub fn get_mut(&mut self, key: K) -> Option<&mut V> { + self.entries.get_mut(&key) + } + + fn use_next_key(&mut self) -> K { + let current = self.next_key; + self.next_key += 1; + K::from(current) + } +} + /// Generate the traits and types from the `wasi-nn` WITX specification. mod gen { use super::*; @@ -42,9 +114,10 @@ mod gen { ) -> anyhow::Result { tracing::debug!("host error: {:?}", e); match e { - WasiNnError::BackendError(_) => unimplemented!(), - WasiNnError::GuestError(_) => unimplemented!(), - WasiNnError::UsageError(_) => unimplemented!(), + WasiNnError::BackendError(_) => Ok(types::NnErrno::RuntimeError), + WasiNnError::GuestError(_) => unimplemented!("guest error conversion"), + WasiNnError::UsageError(_) => Ok(types::NnErrno::UnsupportedOperation), + WasiNnError::NotEnoughMemory(_) => Ok(types::NnErrno::TooLarge), } } } @@ -119,10 +192,10 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { let tensor = crate::wit::types::Tensor { dimensions: memory.to_vec(tensor.dimensions)?, - tensor_type: tensor.type_.into(), + ty: tensor.type_.into(), data: memory.to_vec(tensor.data)?, }; - Ok(exec_context.set_input(index, &tensor)?) + Ok(exec_context.set_input(Id::Index(index), &tensor)?) } else { Err(UsageError::InvalidGraphHandle.into()) } @@ -149,13 +222,19 @@ impl gen::wasi_ephemeral_nn::WasiEphemeralNn for WasiNnCtx { out_buffer_max_size: u32, ) -> Result { if let Some(exec_context) = self.executions.get_mut(exec_context_id.into()) { - let mut destination = memory + let tensor = exec_context.get_output(Id::Index(index))?; + let destination = memory .as_slice_mut(out_buffer.as_array(out_buffer_max_size))? .expect( "cannot use with shared memories; \ see https://github.com/bytecodealliance/wasmtime/issues/5235 (TODO)", ); - Ok(exec_context.get_output(index, &mut destination)?) + if tensor.data.len() > destination.len() { + Err(WasiNnError::NotEnoughMemory(tensor.data.len())) + } else { + destination[..tensor.data.len()].copy_from_slice(&tensor.data); + Ok(tensor.data.len() as u32) + } } else { Err(UsageError::InvalidGraphHandle.into()) } @@ -199,3 +278,28 @@ impl From for crate::wit::types::TensorType { } } } + +/// Possible errors while interacting with [WasiNnCtx]. +#[derive(Debug, Error)] +pub enum WasiNnError { + #[error("backend error")] + BackendError(#[from] BackendError), + #[error("guest error")] + GuestError(#[from] GuestError), + #[error("usage error")] + UsageError(#[from] UsageError), + #[error("not enough memory: requested {0} bytes")] + NotEnoughMemory(usize), +} + +#[derive(Debug, Error)] +pub enum UsageError { + #[error("Only OpenVINO's IR is currently supported, passed encoding: {0:?}")] + InvalidEncoding(GraphEncoding), + #[error("Invalid graph handle; has it been loaded?")] + InvalidGraphHandle, + #[error("Invalid execution context handle; has it been initialized?")] + InvalidExecutionContextHandle, + #[error("No graph found with name: {0}")] + NotFound(String), +} diff --git a/crates/wasi-nn/tests/check/mod.rs b/crates/wasi-nn/tests/check/mod.rs index ffdc099b3009..9f59d38130af 100644 --- a/crates/wasi-nn/tests/check/mod.rs +++ b/crates/wasi-nn/tests/check/mod.rs @@ -1,10 +1,8 @@ -//! This is testing-specific code--it is public only so that it can be -//! accessible both in unit and integration tests. +//! Check that the environment is set up correctly for running tests. //! //! This module checks: -//! - that OpenVINO can be found in the environment -//! - that WinML is available -//! - that some ML model artifacts can be downloaded and cached. +//! - that various backends can be located on the system (see sub-modules) +//! - that certain ML model artifacts can be downloaded and cached. #[allow(unused_imports)] use anyhow::{anyhow, Context, Result}; diff --git a/crates/wasi-nn/tests/exec/mod.rs b/crates/wasi-nn/tests/exec/mod.rs index 23840e7a5d3a..5fd25cbd3cbf 100644 --- a/crates/wasi-nn/tests/exec/mod.rs +++ b/crates/wasi-nn/tests/exec/mod.rs @@ -1,52 +1,6 @@ -use crate::check::artifacts_dir; -use anyhow::Result; -use std::path::Path; -use wasi_common::sync::{Dir, WasiCtxBuilder}; -use wasi_common::WasiCtx; -use wasmtime::{Config, Engine, Linker, Module, Store}; -use wasmtime_wasi_nn::{Backend, InMemoryRegistry, WasiNnCtx}; +//! Provide a Wasmtime embedding for executing wasi-nn test programs. -const PREOPENED_DIR_NAME: &str = "fixture"; +pub mod wit; +pub mod witx; -/// Run a wasi-nn test program. This is modeled after -/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API -/// for file reads. -pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { - let path = Path::new(path); - let engine = Engine::new(&Config::new())?; - let mut linker = Linker::new(&engine); - wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?; - wasi_common::sync::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi)?; - let module = Module::from_file(&engine, path)?; - let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); - let instance = linker.instantiate(&mut store, &module)?; - let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?; - start.call(&mut store, ())?; - Ok(()) -} - -/// The host state for running wasi-nn tests. -struct Ctx { - wasi: WasiCtx, - wasi_nn: WasiNnCtx, -} - -impl Ctx { - fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result { - let preopen_dir = Dir::open_ambient_dir(preopen_dir, cap_std::ambient_authority())?; - let mut builder = WasiCtxBuilder::new(); - builder - .inherit_stdio() - .preopened_dir(preopen_dir, PREOPENED_DIR_NAME)?; - let wasi = builder.build(); - - let mut registry = InMemoryRegistry::new(); - let mobilenet_dir = artifacts_dir(); - if preload_model { - registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?; - } - let wasi_nn = WasiNnCtx::new([backend.into()], registry.into()); - - Ok(Self { wasi, wasi_nn }) - } -} +pub const PREOPENED_DIR_NAME: &str = "fixture"; diff --git a/crates/wasi-nn/tests/exec/wit.rs b/crates/wasi-nn/tests/exec/wit.rs new file mode 100644 index 000000000000..9372238484f7 --- /dev/null +++ b/crates/wasi-nn/tests/exec/wit.rs @@ -0,0 +1,80 @@ +use super::PREOPENED_DIR_NAME; +use crate::check::artifacts_dir; +use anyhow::{anyhow, Result}; +use std::path::Path; +use wasmtime::component::{Component, Linker, ResourceTable}; +use wasmtime::{Config, Engine, Store}; +use wasmtime_wasi::bindings::sync::Command; +use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder}; +use wasmtime_wasi_nn::{wit::WasiNnCtx, Backend, InMemoryRegistry}; + +/// Run a wasi-nn test program. This is modeled after +/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API for +/// file reads. +pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { + let path = Path::new(path); + let engine = Engine::new(&Config::new())?; + let mut linker = Linker::new(&engine); + wasmtime_wasi_nn::wit::add_to_linker(&mut linker)?; + wasmtime_wasi::add_to_linker_sync(&mut linker)?; + let module = Component::from_file(&engine, path)?; + let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); + let command = Command::instantiate(&mut store, &module, &linker)?; + let result = command.wasi_cli_run().call_run(&mut store)?; + result.map_err(|_| anyhow!("failed to run command")) +} + +/// The host state for running wasi-nn component tests. +struct Ctx { + wasi: WasiCtx, + wasi_nn: WasiNnCtx, + table: ResourceTable, +} + +impl Ctx { + fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result { + let mut builder = WasiCtxBuilder::new(); + builder.inherit_stdio().preopened_dir( + preopen_dir, + PREOPENED_DIR_NAME, + DirPerms::READ, + FilePerms::READ, + )?; + let wasi = builder.build(); + + let mut registry = InMemoryRegistry::new(); + let mobilenet_dir = artifacts_dir(); + if preload_model { + registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?; + } + let wasi_nn = WasiNnCtx::new([backend.into()], registry.into()); + + let table = ResourceTable::new(); + + Ok(Self { + wasi, + wasi_nn, + table, + }) + } +} + +impl wasmtime_wasi::WasiView for Ctx { + fn ctx(&mut self) -> &mut WasiCtx { + &mut self.wasi + } + + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } +} + +impl wasmtime_wasi_nn::wit::WasiNnView for Ctx { + fn ctx(&mut self) -> &mut WasiNnCtx { + &mut self.wasi_nn + } + + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } +} diff --git a/crates/wasi-nn/tests/exec/witx.rs b/crates/wasi-nn/tests/exec/witx.rs new file mode 100644 index 000000000000..21feea1dd641 --- /dev/null +++ b/crates/wasi-nn/tests/exec/witx.rs @@ -0,0 +1,52 @@ +use super::PREOPENED_DIR_NAME; +use crate::check::artifacts_dir; +use anyhow::Result; +use std::path::Path; +use wasmtime::{Config, Engine, Linker, Module, Store}; +use wasmtime_wasi::{preview1::WasiP1Ctx, DirPerms, FilePerms, WasiCtxBuilder}; +use wasmtime_wasi_nn::{witx::WasiNnCtx, Backend, InMemoryRegistry}; + +/// Run a wasi-nn test program. This is modeled after +/// `crates/wasi/tests/all/main.rs` but still uses the older preview1 API +/// for file reads. +pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { + let path = Path::new(path); + let engine = Engine::new(&Config::new())?; + let mut linker = Linker::new(&engine); + wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut Ctx| &mut s.wasi_nn)?; + wasmtime_wasi::preview1::add_to_linker_sync(&mut linker, |s: &mut Ctx| &mut s.wasi)?; + let module = Module::from_file(&engine, path)?; + let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); + let instance = linker.instantiate(&mut store, &module)?; + let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?; + start.call(&mut store, ())?; + Ok(()) +} + +/// The host state for running wasi-nn tests. +struct Ctx { + wasi: WasiP1Ctx, + wasi_nn: WasiNnCtx, +} + +impl Ctx { + fn new(preopen_dir: &Path, preload_model: bool, mut backend: Backend) -> Result { + let mut builder = WasiCtxBuilder::new(); + builder.inherit_stdio().preopened_dir( + preopen_dir, + PREOPENED_DIR_NAME, + DirPerms::READ, + FilePerms::READ, + )?; + let wasi = builder.build_p1(); + + let mut registry = InMemoryRegistry::new(); + let mobilenet_dir = artifacts_dir(); + if preload_model { + registry.load((backend).as_dir_loadable().unwrap(), &mobilenet_dir)?; + } + let wasi_nn = WasiNnCtx::new([backend.into()], registry.into()); + + Ok(Self { wasi, wasi_nn }) + } +} diff --git a/crates/wasi-nn/tests/fixtures/readme.md b/crates/wasi-nn/tests/fixtures/README.md similarity index 100% rename from crates/wasi-nn/tests/fixtures/readme.md rename to crates/wasi-nn/tests/fixtures/README.md diff --git a/crates/wasi-nn/tests/test-programs.rs b/crates/wasi-nn/tests/test-programs.rs index 6dfb89a90e58..e375063b1cd8 100644 --- a/crates/wasi-nn/tests/test-programs.rs +++ b/crates/wasi-nn/tests/test-programs.rs @@ -23,6 +23,8 @@ use test_programs_artifacts::*; use wasmtime_wasi_nn::{backend, Backend}; fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + if cfg!(miri) { return Ok(()); } @@ -45,7 +47,7 @@ fn main() -> Result<()> { let mut trials = Vec::new(); for program in programs { // Either ignore the test if it cannot run (i.e., downgrade `Fail` to - // `Ignore`) or pre-emptively fail it if `error_on_failed_check` is set. + // `Ignore`) or preemptively fail it if `error_on_failed_check` is set. let (run_test, mut check) = check_test_program(program); if !error_on_failed_check { check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`. @@ -68,103 +70,122 @@ fn main() -> Result<()> { /// Return the test program to run and a check that must pass for the test to /// run. fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) { - use IgnoreCheck::*; match name { - "nn_image_classification" => ( - nn_image_classification, - if !cfg!(target_arch = "x86_64") { - Fail("requires x86_64".into()) - } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") { - Fail("requires linux or windows".into()) - } else if let Err(e) = check::openvino::is_installed() { - Fail(e.to_string().into()) - } else { - Run - }, + // Legacy WITX-based tests: + "nn_witx_image_classification_openvino" => ( + nn_witx_image_classification_openvino, + IgnoreCheck::for_openvino(), + ), + "nn_witx_image_classification_openvino_named" => ( + nn_witx_image_classification_openvino_named, + IgnoreCheck::for_openvino(), + ), + "nn_witx_image_classification_onnx" => { + (nn_witx_image_classification_onnx, IgnoreCheck::for_onnx()) + } + "nn_witx_image_classification_winml_named" => ( + nn_witx_image_classification_winml_named, + IgnoreCheck::for_winml(), ), - "nn_image_classification_named" => ( - nn_image_classification_named, - if !cfg!(target_arch = "x86_64") { - Fail("requires x86_64".into()) - } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") { - Fail("requires linux or windows or macos".into()) - } else if let Err(e) = check::openvino::is_installed() { - Fail(e.to_string().into()) - } else { - Run - }, + // WIT-based tests: + "nn_wit_image_classification_openvino" => ( + nn_wit_image_classification_openvino, + IgnoreCheck::for_openvino(), ), - "nn_image_classification_onnx" => ( - nn_image_classification_onnx, - #[cfg(feature = "onnx")] - if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") { - Fail("requires x86_64 or aarch64".into()) - } else if !cfg!(target_os = "linux") - && !cfg!(target_os = "windows") - && !cfg!(target_os = "macos") - { - Fail("requires linux, windows, or macos".into()) - } else { - Run - }, - #[cfg(not(feature = "onnx"))] - Ignore("requires the `onnx` feature".into()), + "nn_wit_image_classification_openvino_named" => ( + nn_wit_image_classification_openvino_named, + IgnoreCheck::for_openvino(), ), - "nn_image_classification_winml" => ( - nn_image_classification_winml, - #[cfg(all(feature = "winml", target_os = "windows"))] - if !cfg!(target_arch = "x86_64") { - Fail("requires x86_64".into()) - } else if cfg!(target_os = "windows") { - Fail("requires windows".into()) - } else if let Err(e) = check::winml::is_available() { - Fail(e.to_string().into()) - } else { - Run - }, - #[cfg(not(all(feature = "winml", target_os = "windows")))] - Ignore("requires the `winml` feature on windows".into()), + "nn_wit_image_classification_onnx" => { + (nn_wit_image_classification_onnx, IgnoreCheck::for_onnx()) + } + "nn_wit_image_classification_winml_named" => ( + nn_wit_image_classification_winml_named, + IgnoreCheck::for_winml(), ), _ => panic!("unknown test program: {} (add to this `match`)", name), } } -fn nn_image_classification() -> Result<()> { +fn nn_witx_image_classification_openvino() -> Result<()> { check::openvino::is_installed()?; check::openvino::are_artifacts_available()?; let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION, backend, false) + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO, backend, false) } -fn nn_image_classification_named() -> Result<()> { +fn nn_witx_image_classification_openvino_named() -> Result<()> { check::openvino::is_installed()?; check::openvino::are_artifacts_available()?; let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION_NAMED, backend, true) + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO_NAMED, backend, true) } #[cfg(feature = "onnx")] -fn nn_image_classification_onnx() -> Result<()> { +fn nn_witx_image_classification_onnx() -> Result<()> { check::onnx::are_artifacts_available()?; - let backend = Backend::from(backend::onnxruntime::OnnxBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false) + let backend = Backend::from(backend::onnx::OnnxBackend::default()); + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false) } - #[cfg(not(feature = "onnx"))] -fn nn_image_classification_onnx() -> Result<()> { +fn nn_witx_image_classification_onnx() -> Result<()> { anyhow::bail!("this test requires the `onnx` feature") } #[cfg(all(feature = "winml", target_os = "windows"))] -fn nn_image_classification_winml() -> Result<()> { +fn nn_witx_image_classification_winml_named() -> Result<()> { check::winml::is_available()?; check::onnx::are_artifacts_available()?; let backend = Backend::from(backend::winml::WinMLBackend::default()); - exec::run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false) + exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false) +} +#[cfg(not(all(feature = "winml", target_os = "windows")))] +fn nn_witx_image_classification_winml_named() -> Result<()> { + anyhow::bail!("this test requires the `winml` feature and only runs on windows") } +fn nn_wit_image_classification_openvino() -> Result<()> { + check::openvino::is_installed()?; + check::openvino::are_artifacts_available()?; + let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); + exec::wit::run( + NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_COMPONENT, + backend, + false, + ) +} + +fn nn_wit_image_classification_openvino_named() -> Result<()> { + check::openvino::is_installed()?; + check::openvino::are_artifacts_available()?; + let backend = Backend::from(backend::openvino::OpenvinoBackend::default()); + exec::wit::run( + NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_NAMED_COMPONENT, + backend, + true, + ) +} + +#[cfg(feature = "onnx")] +fn nn_wit_image_classification_onnx() -> Result<()> { + check::onnx::are_artifacts_available()?; + let backend = Backend::from(backend::onnx::OnnxBackend::default()); + exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false) +} +#[cfg(not(feature = "onnx"))] +fn nn_wit_image_classification_onnx() -> Result<()> { + anyhow::bail!("this test requires the `onnx` feature") +} + +#[cfg(all(feature = "winml", target_os = "windows"))] +fn nn_wit_image_classification_winml_named() -> Result<()> { + check::winml::is_available()?; + check::onnx::are_artifacts_available()?; + let backend = Backend::from(backend::winml::WinMLBackend::default()); + exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false) +} #[cfg(not(all(feature = "winml", target_os = "windows")))] -fn nn_image_classification_winml() -> Result<()> { +fn nn_wit_image_classification_winml_named() -> Result<()> { anyhow::bail!("this test requires the `winml` feature and only runs on windows") } @@ -197,3 +218,52 @@ impl IgnoreCheck { matches!(self, IgnoreCheck::Ignore(_)) } } + +/// Some pre-test checks for various backends. +impl IgnoreCheck { + fn for_openvino() -> IgnoreCheck { + use IgnoreCheck::*; + if !cfg!(target_arch = "x86_64") { + Fail("requires x86_64".into()) + } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") { + Fail("requires linux or windows or macos".into()) + } else if let Err(e) = check::openvino::is_installed() { + Fail(e.to_string().into()) + } else { + Run + } + } + + fn for_onnx() -> Self { + use IgnoreCheck::*; + #[cfg(feature = "onnx")] + if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") { + Fail("requires x86_64 or aarch64".into()) + } else if !cfg!(target_os = "linux") + && !cfg!(target_os = "windows") + && !cfg!(target_os = "macos") + { + Fail("requires linux, windows, or macos".into()) + } else { + Run + } + #[cfg(not(feature = "onnx"))] + Ignore("requires the `onnx` feature".into()) + } + + fn for_winml() -> IgnoreCheck { + use IgnoreCheck::*; + #[cfg(all(feature = "winml", target_os = "windows"))] + if !cfg!(target_arch = "x86_64") { + Fail("requires x86_64".into()) + } else if !cfg!(target_os = "windows") { + Fail("requires windows".into()) + } else if let Err(e) = check::winml::is_available() { + Fail(e.to_string().into()) + } else { + Run + } + #[cfg(not(all(feature = "winml", target_os = "windows")))] + Ignore("requires the `winml` feature on windows".into()) + } +} diff --git a/crates/wasi-nn/wit/wasi-nn.wit b/crates/wasi-nn/wit/wasi-nn.wit index 19e3de875d61..b8ffd22e8c04 100644 --- a/crates/wasi-nn/wit/wasi-nn.wit +++ b/crates/wasi-nn/wit/wasi-nn.wit @@ -43,16 +43,18 @@ interface tensor { /// memory--e.g., using row-major ordering--and could perhaps be improved. type tensor-data = list; - record tensor { + resource tensor { + constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data); + // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor // containing a single value, use `[1]` for the tensor dimensions. - dimensions: tensor-dimensions, + dimensions: func() -> tensor-dimensions; // Describe the type of element in the tensor (e.g., `f32`). - tensor-type: tensor-type, + ty: func() -> tensor-type; - // Contains the tensor data. - data: tensor-data, + // Return the tensor data. + data: func() -> tensor-data; } } @@ -61,11 +63,12 @@ interface tensor { interface graph { use errors.{error}; use tensor.{tensor}; + use inference.{graph-execution-context}; /// An execution graph for performing inference (i.e., a model). - /// - /// TODO: replace with `resource` (https://github.com/WebAssembly/wasi-nn/issues/47). - type graph = u32; + resource graph { + init-execution-context: func() -> result; + } /// Describes the encoding of the graph. This allows the API to be implemented by various /// backends that encode (i.e., serialize) their graph IR with different formats. @@ -75,6 +78,7 @@ interface graph { tensorflow, pytorch, tensorflowlite, + ggml, autodetect, } @@ -107,27 +111,25 @@ interface graph { interface inference { use errors.{error}; use tensor.{tensor, tensor-data}; - use graph.{graph}; /// Bind a `graph` to the input and output tensors for an inference. /// - /// TODO: this is no longer necessary in WIT (https://github.com/WebAssembly/wasi-nn/issues/43) - type graph-execution-context = u32; - - /// Create an execution instance of a loaded graph. - init-execution-context: func(graph: graph) -> result; - - /// Define the inputs to use for inference. - set-input: func(ctx: graph-execution-context, index: u32, tensor: tensor) -> result<_, error>; - - /// Compute the inference on the given inputs. - /// - /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this - /// expectation could be removed as a part of https://github.com/WebAssembly/wasi-nn/issues/43. - compute: func(ctx: graph-execution-context) -> result<_, error>; - - /// Extract the outputs after inference. - get-output: func(ctx: graph-execution-context, index: u32) -> result; + /// TODO: this may no longer be necessary in WIT + /// (https://github.com/WebAssembly/wasi-nn/issues/43) + resource graph-execution-context { + /// Define the inputs to use for inference. + set-input: func(name: string, tensor: tensor) -> result<_, error>; + + /// Compute the inference on the given inputs. + /// + /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this + /// expectation could be removed as a part of + /// https://github.com/WebAssembly/wasi-nn/issues/43. + compute: func() -> result<_, error>; + + /// Extract the outputs after inference. + get-output: func(name: string) -> result; + } } /// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42) @@ -137,7 +139,8 @@ interface errors { invalid-argument, // Invalid encoding. invalid-encoding, - busy, + // The operation timed out. + timeout, // Runtime Error. runtime-error, // Unsupported operation. diff --git a/src/commands/run.rs b/src/commands/run.rs index 317b91aa22d7..01085fbc361e 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -18,7 +18,7 @@ use wasmtime::{Engine, Func, Module, Store, StoreLimits, Val, ValType}; use wasmtime_wasi::WasiView; #[cfg(feature = "wasi-nn")] -use wasmtime_wasi_nn::WasiNnCtx; +use wasmtime_wasi_nn::wit::WasiNnCtx; #[cfg(feature = "wasi-threads")] use wasmtime_wasi_threads::WasiThreadsCtx; @@ -624,40 +624,27 @@ impl RunCommand { { bail!("Cannot enable wasi-nn when the binary is not compiled with this feature."); } - #[cfg(feature = "wasi-nn")] + #[cfg(all(feature = "wasi-nn", feature = "component-model"))] { + let (backends, registry) = self.collect_preloaded_nn_graphs()?; match linker { CliLinker::Core(linker) => { wasmtime_wasi_nn::witx::add_to_linker(linker, |host| { - // This WASI proposal is currently not protected against - // concurrent access--i.e., when wasi-threads is actively - // spawning new threads, we cannot (yet) safely allow access and - // fail if more than one thread has `Arc`-references to the - // context. Once this proposal is updated (as wasi-common has - // been) to allow concurrent access, this `Arc::get_mut` - // limitation can be removed. - Arc::get_mut(host.wasi_nn.as_mut().unwrap()) + Arc::get_mut(host.wasi_nn_witx.as_mut().unwrap()) .expect("wasi-nn is not implemented with multi-threading support") })?; + store.data_mut().wasi_nn_witx = Some(Arc::new( + wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry), + )); } #[cfg(feature = "component-model")] CliLinker::Component(linker) => { - wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| { - Arc::get_mut(host.wasi_nn.as_mut().unwrap()) - .expect("wasi-nn is not implemented with multi-threading support") - })?; + wasmtime_wasi_nn::wit::add_to_linker(linker)?; + store.data_mut().wasi_nn_wit = Some(Arc::new( + wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry), + )); } } - let graphs = self - .run - .common - .wasi - .nn_graph - .iter() - .map(|g| (g.format.clone(), g.dir.clone())) - .collect::>(); - let (backends, registry) = wasmtime_wasi_nn::preload(&graphs)?; - store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new(backends, registry))); } } @@ -767,6 +754,21 @@ impl RunCommand { store.data_mut().preview2_ctx = Some(Arc::new(Mutex::new(ctx))); Ok(()) } + + #[cfg(feature = "wasi-nn")] + fn collect_preloaded_nn_graphs( + &self, + ) -> Result<(Vec, wasmtime_wasi_nn::Registry)> { + let graphs = self + .run + .common + .wasi + .nn_graph + .iter() + .map(|g| (g.format.clone(), g.dir.clone())) + .collect::>(); + wasmtime_wasi_nn::preload(&graphs) + } } #[derive(Default, Clone)] @@ -779,7 +781,10 @@ struct Host { preview2_ctx: Option>>, #[cfg(feature = "wasi-nn")] - wasi_nn: Option>, + wasi_nn_wit: Option>, + #[cfg(feature = "wasi-nn")] + wasi_nn_witx: Option>, + #[cfg(feature = "wasi-threads")] wasi_threads: Option>>, #[cfg(feature = "wasi-http")] @@ -824,6 +829,18 @@ impl wasmtime_wasi_http::types::WasiHttpView for Host { } } +#[cfg(feature = "wasi-nn")] +impl wasmtime_wasi_nn::wit::WasiNnView for Host { + fn ctx(&mut self) -> &mut WasiNnCtx { + let ctx = self.wasi_nn_wit.as_mut().unwrap(); + Arc::get_mut(ctx).expect("wasmtime_wasi_nn is not compatible with threads") + } + + fn table(&mut self) -> &mut wasmtime_wasi::ResourceTable { + self.preview2_ctx().table() + } +} + #[cfg(not(unix))] fn ctx_set_listenfd(num_fd: usize, _builder: &mut WasiCtxBuilder) -> Result { Ok(num_fd) diff --git a/src/commands/serve.rs b/src/commands/serve.rs index 56c2af9d3024..8f9a5136ba62 100644 --- a/src/commands/serve.rs +++ b/src/commands/serve.rs @@ -17,7 +17,7 @@ use wasmtime_wasi_http::io::TokioIo; use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpCtx, WasiHttpView}; #[cfg(feature = "wasi-nn")] -use wasmtime_wasi_nn::WasiNnCtx; +use wasmtime_wasi_nn::wit::WasiNnCtx; struct Host { table: wasmtime::component::ResourceTable, @@ -50,6 +50,17 @@ impl WasiHttpView for Host { } } +#[cfg(feature = "wasi-nn")] +impl wasmtime_wasi_nn::wit::WasiNnView for Host { + fn table(&mut self) -> &mut wasmtime::component::ResourceTable { + &mut self.table + } + + fn ctx(&mut self) -> &mut WasiNnCtx { + self.nn.as_mut().unwrap() + } +} + const DEFAULT_ADDR: std::net::SocketAddr = std::net::SocketAddr::new( std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 8080, @@ -75,15 +86,8 @@ impl ServeCommand { pub fn execute(mut self) -> Result<()> { self.run.common.init_logging()?; - // We force cli errors before starting to listen for connections so then we don't - // accidentally delay them to the first request. - if self.run.common.wasi.nn == Some(true) { - #[cfg(not(feature = "wasi-nn"))] - { - bail!("Cannot enable wasi-nn when the binary is not compiled with this feature."); - } - } - + // We force cli errors before starting to listen for connections so then + // we don't accidentally delay them to the first request. if let Some(Profile::Guest { .. }) = &self.run.profile { bail!("Cannot use the guest profiler with components"); } @@ -99,8 +103,8 @@ impl ServeCommand { bail!("wasi-threads does not support components yet") } - // The serve command requires both wasi-http and the component model, so we enable those by - // default here. + // The serve command requires both wasi-http and the component model, so + // we enable those by default here. if self.run.common.wasi.http.replace(true) == Some(false) { bail!("wasi-http is required for the serve command, and must not be disabled"); } @@ -227,7 +231,7 @@ impl ServeCommand { } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::wit::ML::add_to_linker(linker, |host| host.nn.as_mut().unwrap())?; + wasmtime_wasi_nn::wit::add_to_linker(linker)?; } } From c9100b14849b5a90d35f8b0f18f6e50ef1d8eca1 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Wed, 26 Jun 2024 15:21:48 -0700 Subject: [PATCH 2/5] vet: audit `ort`-related crate updates --- supply-chain/audits.toml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/supply-chain/audits.toml b/supply-chain/audits.toml index e29b1cfc4fd0..7cc1fa51a95a 100644 --- a/supply-chain/audits.toml +++ b/supply-chain/audits.toml @@ -2028,6 +2028,12 @@ criteria = "safe-to-deploy" version = "0.46.0" notes = "one use of unsafe to call windows specific api to get console handle." +[[audits.num-traits]] +who = "Andrew Brown " +criteria = "safe-to-deploy" +version = "0.2.19" +notes = "As advertised: a numeric library. The only `unsafe` is from some float-to-int conversions, which seems expected." + [[audits.num_cpus]] who = "Alex Crichton " criteria = "safe-to-deploy" @@ -2145,12 +2151,24 @@ criteria = "safe-to-deploy" version = "2.0.0-rc.0" notes = "As expected, this crate uses `unsafe` to access the `unsafe` `ort-sys` FFI functions; it also includes several `unsafe` implementations of `Send` for several structures. With the `load-dynamic` feature enabled, this crate will be `libloading` external libraries to call FFI functions. With the `fetch-models` feature enabled, this crate can also download arbitrary models to the local filesystem." +[[audits.ort]] +who = "Andrew Brown " +criteria = "safe-to-deploy" +delta = "2.0.0-rc.0 -> 2.0.0-rc.2" +notes = "Same as previous audit: the crate inherently uses `unsafe` FFI calls for using ONNX through `ort-sys` (e.g., logging C error strings). The changes are relatively uninteresting: a lot of documentation, some `must_use`, and general refactoring due to changes in the underlying API." + [[audits.ort-sys]] who = "Andrew Brown " criteria = "safe-to-deploy" version = "2.0.0-rc.0" notes = "As expected, this crate contains a significant number of `unsafe` definitions to expose the FFI surface of the ONNX libraries. Perhaps surprisingly, it also contains some `unsafe` system calls to locate the user's home directory. Another interesting bit is the `build.rs` script: with the `download-binaries` feature enabled, this script will retrieve and link various ONNX libraries from https://parcel.pyke.io. This seems par for the course with this kind of library, though; the alternative--attempting to find the library on an arbitrary system--can be quite complex." +[[audits.ort-sys]] +who = "Andrew Brown " +criteria = "safe-to-deploy" +delta = "2.0.0-rc.0 -> 2.0.0-rc.2" +notes = "This crate still downloads the ONNX libraries as a part of the `build.rs` script; now with more platform options for pre-built binaries stored in a `dist.txt` file. Otherwise largely unchanged since the previous audit." + [[audits.overload]] who = "Pat Hickey " criteria = "safe-to-deploy" From 71dce15b7b9d44932dbf4b46f8292ed3bab7ce6d Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Thu, 27 Jun 2024 11:16:18 -0700 Subject: [PATCH 3/5] Simplify `WasiNnView` With @alexcrichton's help, this change removes the `trait WasiNnView` and `struct WasiNnImpl` wrapping that the WIT-based implementation used for accessing the host context. Instead, `WasiNnView` is now a `struct` containing the mutable references it needs to make things work. This unwraps one complex layer of abstraction, though it does have the downside that it complicates CLI code to split borrows of `Host`. --- crates/wasi-nn/src/wit.rs | 108 ++++++++++++++----------------- crates/wasi-nn/tests/exec/wit.rs | 15 ++--- src/commands/run.rs | 26 ++++---- src/commands/serve.rs | 16 ++--- 4 files changed, 70 insertions(+), 95 deletions(-) diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index 2f5ff2f7c53e..4c182cb8bcd6 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -36,13 +36,22 @@ impl WasiNnCtx { } } -/// A trait for modifying internal wasi-nn state. +/// A wrapper capturing the needed internal wasi-nn state. /// -/// This follows the pattern used by other WASI proposals (see `wasmtime-wasi`, -/// `wasmtime-wasi-http`). -pub trait WasiNnView { - fn ctx(&mut self) -> &mut WasiNnCtx; - fn table(&mut self) -> &mut ResourceTable; +/// Unlike other WASI proposals (see `wasmtime-wasi`, `wasmtime-wasi-http`), +/// this wrapper is not a `trait` but rather holds the references directly. This +/// remove one layer of abstraction for simplicity only, and could be added back +/// in the future if embedders need more control here. +pub struct WasiNnView<'a> { + ctx: &'a mut WasiNnCtx, + table: &'a mut ResourceTable, +} + +impl<'a> WasiNnView<'a> { + /// Create a new view into the wasi-nn state. + pub fn new(table: &'a mut ResourceTable, ctx: &'a mut WasiNnCtx) -> Self { + Self { ctx, table } + } } /// Generate the traits and types from the `wasi-nn` WIT specification. @@ -76,37 +85,20 @@ pub use gen::inference::GraphExecutionContext; pub use gen::tensor::{Tensor, TensorData, TensorDimensions, TensorType}; pub use gen_::Ml as ML; -pub fn add_to_linker(l: &mut wasmtime::component::Linker) -> anyhow::Result<()> -where - T: WasiNnView, -{ - let closure = type_annotate_nn::(|t| WasiNnImpl(t)); - gen::graph::add_to_linker_get_host(l, closure)?; - gen::tensor::add_to_linker_get_host(l, closure)?; - gen::inference::add_to_linker_get_host(l, closure)?; - gen::errors::add_to_linker_get_host(l, closure)?; +/// Add the WIT-based version of the `wasi-nn` API to a +/// [`wasmtime::component::Linker`]. +pub fn add_to_linker( + l: &mut wasmtime::component::Linker, + f: impl Fn(&mut T) -> WasiNnView<'_> + Send + Sync + Copy + 'static, +) -> anyhow::Result<()> { + gen::graph::add_to_linker_get_host(l, f)?; + gen::tensor::add_to_linker_get_host(l, f)?; + gen::inference::add_to_linker_get_host(l, f)?; + gen::errors::add_to_linker_get_host(l, f)?; Ok(()) } -pub struct WasiNnImpl(pub T); - -fn type_annotate_nn(val: F) -> F -where - F: Fn(&mut T) -> WasiNnImpl<&mut T>, -{ - val -} - -impl WasiNnView for WasiNnImpl<&mut T> { - fn ctx(&mut self) -> &mut WasiNnCtx { - self.0.ctx() - } - fn table(&mut self) -> &mut wasmtime::component::ResourceTable { - self.0.table() - } -} - -impl gen::graph::Host for WasiNnImpl<&mut T> { +impl gen::graph::Host for WasiNnView<'_> { fn load( &mut self, builders: Vec, @@ -115,11 +107,11 @@ impl gen::graph::Host for WasiNnImpl<&mut T> { ) -> wasmtime::Result, Error>> { use core::result::Result::*; tracing::debug!("load {encoding:?} {target:?}"); - let result = if let Some(backend) = self.ctx().backends.get_mut(&encoding) { + let result = if let Some(backend) = self.ctx.backends.get_mut(&encoding) { let slices = builders.iter().map(|s| s.as_slice()).collect::>(); match backend.load(&slices, target.into()) { Ok(graph) => { - let graph = self.0.table().push(graph)?; + let graph = self.table.push(graph)?; Ok(graph) } Err(error) => { @@ -139,10 +131,10 @@ impl gen::graph::Host for WasiNnImpl<&mut T> { ) -> wasmtime::Result, Error>> { use core::result::Result::*; tracing::debug!("load by name {name:?}"); - let registry = &self.ctx().registry; + let registry = &self.ctx.registry; let result = if let Some(graph) = registry.get(&name) { let graph = graph.clone(); - let graph = self.0.table().push(graph)?; + let graph = self.table.push(graph)?; Ok(graph) } else { tracing::error!("failed to find graph with name: {name}"); @@ -152,17 +144,17 @@ impl gen::graph::Host for WasiNnImpl<&mut T> { } } -impl gen::graph::HostGraph for WasiNnImpl<&mut T> { +impl gen::graph::HostGraph for WasiNnView<'_> { fn init_execution_context( &mut self, graph: Resource, ) -> wasmtime::Result, Error>> { use core::result::Result::*; tracing::debug!("initialize execution context"); - let graph = self.0.table().get(&graph)?; + let graph = self.table.get(&graph)?; let result = match graph.init_execution_context() { Ok(exec_context) => { - let exec_context = self.0.table().push(exec_context)?; + let exec_context = self.table.push(exec_context)?; Ok(exec_context) } Err(error) => { @@ -174,12 +166,12 @@ impl gen::graph::HostGraph for WasiNnImpl<&mut T> { } fn drop(&mut self, graph: Resource) -> wasmtime::Result<()> { - self.0.table().delete(graph)?; + self.table.delete(graph)?; Ok(()) } } -impl gen::inference::HostGraphExecutionContext for WasiNnImpl<&mut T> { +impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> { fn set_input( &mut self, exec_context: Resource, @@ -187,10 +179,10 @@ impl gen::inference::HostGraphExecutionContext for WasiNnImpl<&mu tensor: Resource, ) -> wasmtime::Result> { use core::result::Result::*; - let tensor = self.0.table().get(&tensor)?; + let tensor = self.table.get(&tensor)?; tracing::debug!("set input {name:?}: {tensor:?}"); let tensor = tensor.clone(); // TODO: avoid copying the tensor - let exec_context = self.0.table().get_mut(&exec_context)?; + let exec_context = self.table.get_mut(&exec_context)?; let result = if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) { tracing::error!("failed to set input: {e:?}"); Err(Error::InvalidArgument) @@ -205,7 +197,7 @@ impl gen::inference::HostGraphExecutionContext for WasiNnImpl<&mu exec_context: Resource, ) -> wasmtime::Result> { use core::result::Result::*; - let exec_context = self.0.table().get_mut(&exec_context)?; + let exec_context = &mut self.table.get_mut(&exec_context)?; tracing::debug!("compute"); let result = match exec_context.compute() { Ok(()) => Ok(()), @@ -224,11 +216,11 @@ impl gen::inference::HostGraphExecutionContext for WasiNnImpl<&mu name: String, ) -> wasmtime::Result, Error>> { use core::result::Result::*; - let exec_context = self.0.table().get_mut(&exec_context)?; + let exec_context = self.table.get_mut(&exec_context)?; tracing::debug!("get output {name:?}"); let result = match exec_context.get_output(Id::Name(name)) { Ok(tensor) => { - let tensor = self.0.table().push(tensor)?; + let tensor = self.table.push(tensor)?; Ok(tensor) } Err(error) => { @@ -240,12 +232,12 @@ impl gen::inference::HostGraphExecutionContext for WasiNnImpl<&mu } fn drop(&mut self, exec_context: Resource) -> wasmtime::Result<()> { - self.0.table().delete(exec_context)?; + self.table.delete(exec_context)?; Ok(()) } } -impl gen::tensor::HostTensor for WasiNnImpl<&mut T> { +impl gen::tensor::HostTensor for WasiNnView<'_> { fn new( &mut self, dimensions: TensorDimensions, @@ -257,34 +249,34 @@ impl gen::tensor::HostTensor for WasiNnImpl<&mut T> { ty, data, }; - let tensor = self.0.table().push(tensor)?; + let tensor = self.table.push(tensor)?; Ok(tensor) } fn dimensions(&mut self, tensor: Resource) -> wasmtime::Result { - let tensor = self.0.table().get(&tensor)?; + let tensor = self.table.get(&tensor)?; Ok(tensor.dimensions.clone()) } fn ty(&mut self, tensor: Resource) -> wasmtime::Result { - let tensor = self.0.table().get(&tensor)?; + let tensor = self.table.get(&tensor)?; Ok(tensor.ty) } fn data(&mut self, tensor: Resource) -> wasmtime::Result { - let tensor = self.0.table().get(&tensor)?; + let tensor = self.table.get(&tensor)?; Ok(tensor.data.clone()) } fn drop(&mut self, tensor: Resource) -> wasmtime::Result<()> { - self.0.table().delete(tensor)?; + self.table.delete(tensor)?; Ok(()) } } -impl gen::tensor::Host for WasiNnImpl<&mut T> {} -impl gen::errors::Host for WasiNnImpl<&mut T> {} -impl gen::inference::Host for WasiNnImpl<&mut T> {} +impl gen::tensor::Host for WasiNnView<'_> {} +impl gen::errors::Host for WasiNnView<'_> {} +impl gen::inference::Host for WasiNnView<'_> {} impl Hash for gen::graph::GraphEncoding { fn hash(&self, state: &mut H) { diff --git a/crates/wasi-nn/tests/exec/wit.rs b/crates/wasi-nn/tests/exec/wit.rs index 9372238484f7..5f2d546d667d 100644 --- a/crates/wasi-nn/tests/exec/wit.rs +++ b/crates/wasi-nn/tests/exec/wit.rs @@ -6,6 +6,7 @@ use wasmtime::component::{Component, Linker, ResourceTable}; use wasmtime::{Config, Engine, Store}; use wasmtime_wasi::bindings::sync::Command; use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder}; +use wasmtime_wasi_nn::wit::WasiNnView; use wasmtime_wasi_nn::{wit::WasiNnCtx, Backend, InMemoryRegistry}; /// Run a wasi-nn test program. This is modeled after @@ -15,7 +16,9 @@ pub fn run(path: &str, backend: Backend, preload_model: bool) -> Result<()> { let path = Path::new(path); let engine = Engine::new(&Config::new())?; let mut linker = Linker::new(&engine); - wasmtime_wasi_nn::wit::add_to_linker(&mut linker)?; + wasmtime_wasi_nn::wit::add_to_linker(&mut linker, |c: &mut Ctx| { + WasiNnView::new(&mut c.table, &mut c.wasi_nn) + })?; wasmtime_wasi::add_to_linker_sync(&mut linker)?; let module = Component::from_file(&engine, path)?; let mut store = Store::new(&engine, Ctx::new(&artifacts_dir(), preload_model, backend)?); @@ -68,13 +71,3 @@ impl wasmtime_wasi::WasiView for Ctx { &mut self.table } } - -impl wasmtime_wasi_nn::wit::WasiNnView for Ctx { - fn ctx(&mut self) -> &mut WasiNnCtx { - &mut self.wasi_nn - } - - fn table(&mut self) -> &mut ResourceTable { - &mut self.table - } -} diff --git a/src/commands/run.rs b/src/commands/run.rs index 01085fbc361e..e5b3b816a28a 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -18,7 +18,7 @@ use wasmtime::{Engine, Func, Module, Store, StoreLimits, Val, ValType}; use wasmtime_wasi::WasiView; #[cfg(feature = "wasi-nn")] -use wasmtime_wasi_nn::wit::WasiNnCtx; +use wasmtime_wasi_nn::wit::WasiNnView; #[cfg(feature = "wasi-threads")] use wasmtime_wasi_threads::WasiThreadsCtx; @@ -639,7 +639,17 @@ impl RunCommand { } #[cfg(feature = "component-model")] CliLinker::Component(linker) => { - wasmtime_wasi_nn::wit::add_to_linker(linker)?; + wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| { + let preview2_ctx = + h.preview2_ctx.as_mut().expect("wasip2 is not configured"); + let preview2_ctx = Arc::get_mut(preview2_ctx) + .expect("wasmtime_wasi is not compatible with threads") + .get_mut() + .unwrap(); + let nn_ctx = Arc::get_mut(h.wasi_nn_wit.as_mut().unwrap()) + .expect("wasi-nn is not implemented with multi-threading support"); + WasiNnView::new(preview2_ctx.table(), nn_ctx) + })?; store.data_mut().wasi_nn_wit = Some(Arc::new( wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry), )); @@ -829,18 +839,6 @@ impl wasmtime_wasi_http::types::WasiHttpView for Host { } } -#[cfg(feature = "wasi-nn")] -impl wasmtime_wasi_nn::wit::WasiNnView for Host { - fn ctx(&mut self) -> &mut WasiNnCtx { - let ctx = self.wasi_nn_wit.as_mut().unwrap(); - Arc::get_mut(ctx).expect("wasmtime_wasi_nn is not compatible with threads") - } - - fn table(&mut self) -> &mut wasmtime_wasi::ResourceTable { - self.preview2_ctx().table() - } -} - #[cfg(not(unix))] fn ctx_set_listenfd(num_fd: usize, _builder: &mut WasiCtxBuilder) -> Result { Ok(num_fd) diff --git a/src/commands/serve.rs b/src/commands/serve.rs index 8f9a5136ba62..8a200ee093f3 100644 --- a/src/commands/serve.rs +++ b/src/commands/serve.rs @@ -50,17 +50,6 @@ impl WasiHttpView for Host { } } -#[cfg(feature = "wasi-nn")] -impl wasmtime_wasi_nn::wit::WasiNnView for Host { - fn table(&mut self) -> &mut wasmtime::component::ResourceTable { - &mut self.table - } - - fn ctx(&mut self) -> &mut WasiNnCtx { - self.nn.as_mut().unwrap() - } -} - const DEFAULT_ADDR: std::net::SocketAddr = std::net::SocketAddr::new( std::net::IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)), 8080, @@ -231,7 +220,10 @@ impl ServeCommand { } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::wit::add_to_linker(linker)?; + wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut Host| { + let ctx = h.nn.as_mut().unwrap(); + wasmtime_wasi_nn::wit::WasiNnView::new(&mut h.table, ctx) + })?; } } From df9b5e1b3a545975a7c45ed3cc8010b82c6d3026 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Thu, 27 Jun 2024 11:22:33 -0700 Subject: [PATCH 4/5] Temporarily disable WIT check --- ci/vendor-wit.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 30c1cbb9283c..53ce9ad80cb2 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -36,5 +36,9 @@ cp -r $dst crates/wasi-http/wit # slightly different than above. repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn revision=e2310b -curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx +# TODO: the in-tree `wasi-nn` implementation does not yet fully support the +# latest WIT specification on `main`. To create a baseline for moving forward, +# the in-tree WIT incorporates some but not all of the upstream changes. This +# TODO can be removed once the implementation catches up with the spec. +# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit From 14df2e26c756ed3f4f565acba76351e769165895 Mon Sep 17 00:00:00 2001 From: Andrew Brown Date: Thu, 27 Jun 2024 11:53:06 -0700 Subject: [PATCH 5/5] Refactor errors to use `trappable_error_type` This change simplifies the return types of the host implementations of the WIT-based wasi-nn. There is more work to be done with errors, e.g., to catch up with the upstream decision to return errors as resources. But this is better than the previous mess. --- crates/wasi-nn/src/wit.rs | 96 ++++++++++++++++++++++++--------------- 1 file changed, 60 insertions(+), 36 deletions(-) diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index 4c182cb8bcd6..40f6fc4c1ff6 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -54,6 +54,31 @@ impl<'a> WasiNnView<'a> { } } +pub enum Error { + /// Caller module passed an invalid argument. + InvalidArgument, + /// Invalid encoding. + InvalidEncoding, + /// The operation timed out. + Timeout, + /// Runtime Error. + RuntimeError, + /// Unsupported operation. + UnsupportedOperation, + /// Graph is too large. + TooLarge, + /// Graph not found. + NotFound, + /// A runtime error occurred that we should trap on; see `StreamError`. + Trap(anyhow::Error), +} + +impl From for Error { + fn from(error: wasmtime::component::ResourceTableError) -> Self { + Self::Trap(error.into()) + } +} + /// Generate the traits and types from the `wasi-nn` WIT specification. mod gen_ { wasmtime::component::bindgen!({ @@ -67,6 +92,9 @@ mod gen_ { "wasi:nn/tensor/tensor": crate::Tensor, "wasi:nn/inference/graph-execution-context": crate::ExecutionContext, }, + trappable_error_type: { + "wasi:nn/errors/error" => super::Error, + }, }); } use gen_::wasi::nn::{self as gen}; // Shortcut to the module containing the types we need. @@ -79,7 +107,6 @@ pub mod types { pub use gen::inference::GraphExecutionContext; pub use gen::tensor::{Tensor, TensorType}; } -pub use gen::errors::Error; pub use gen::graph::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding}; pub use gen::inference::GraphExecutionContext; pub use gen::tensor::{Tensor, TensorData, TensorDimensions, TensorType}; @@ -104,10 +131,9 @@ impl gen::graph::Host for WasiNnView<'_> { builders: Vec, encoding: GraphEncoding, target: ExecutionTarget, - ) -> wasmtime::Result, Error>> { - use core::result::Result::*; + ) -> Result, Error> { tracing::debug!("load {encoding:?} {target:?}"); - let result = if let Some(backend) = self.ctx.backends.get_mut(&encoding) { + if let Some(backend) = self.ctx.backends.get_mut(&encoding) { let slices = builders.iter().map(|s| s.as_slice()).collect::>(); match backend.load(&slices, target.into()) { Ok(graph) => { @@ -121,26 +147,21 @@ impl gen::graph::Host for WasiNnView<'_> { } } else { Err(Error::InvalidEncoding) - }; - wasmtime::Result::Ok(result) + } } - fn load_by_name( - &mut self, - name: String, - ) -> wasmtime::Result, Error>> { + fn load_by_name(&mut self, name: String) -> Result, Error> { use core::result::Result::*; tracing::debug!("load by name {name:?}"); let registry = &self.ctx.registry; - let result = if let Some(graph) = registry.get(&name) { + if let Some(graph) = registry.get(&name) { let graph = graph.clone(); let graph = self.table.push(graph)?; Ok(graph) } else { tracing::error!("failed to find graph with name: {name}"); Err(Error::NotFound) - }; - wasmtime::Result::Ok(result) + } } } @@ -148,11 +169,11 @@ impl gen::graph::HostGraph for WasiNnView<'_> { fn init_execution_context( &mut self, graph: Resource, - ) -> wasmtime::Result, Error>> { + ) -> Result, Error> { use core::result::Result::*; tracing::debug!("initialize execution context"); let graph = self.table.get(&graph)?; - let result = match graph.init_execution_context() { + match graph.init_execution_context() { Ok(exec_context) => { let exec_context = self.table.push(exec_context)?; Ok(exec_context) @@ -161,8 +182,7 @@ impl gen::graph::HostGraph for WasiNnView<'_> { tracing::error!("failed to initialize execution context: {error:?}"); Err(Error::RuntimeError) } - }; - wasmtime::Result::Ok(result) + } } fn drop(&mut self, graph: Resource) -> wasmtime::Result<()> { @@ -177,36 +197,29 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> { exec_context: Resource, name: String, tensor: Resource, - ) -> wasmtime::Result> { - use core::result::Result::*; + ) -> Result<(), Error> { let tensor = self.table.get(&tensor)?; tracing::debug!("set input {name:?}: {tensor:?}"); let tensor = tensor.clone(); // TODO: avoid copying the tensor let exec_context = self.table.get_mut(&exec_context)?; - let result = if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) { + if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) { tracing::error!("failed to set input: {e:?}"); Err(Error::InvalidArgument) } else { Ok(()) - }; - wasmtime::Result::Ok(result) + } } - fn compute( - &mut self, - exec_context: Resource, - ) -> wasmtime::Result> { - use core::result::Result::*; + fn compute(&mut self, exec_context: Resource) -> Result<(), Error> { let exec_context = &mut self.table.get_mut(&exec_context)?; tracing::debug!("compute"); - let result = match exec_context.compute() { + match exec_context.compute() { Ok(()) => Ok(()), Err(error) => { tracing::error!("failed to compute: {error:?}"); Err(Error::RuntimeError) } - }; - wasmtime::Result::Ok(result) + } } #[doc = r" Extract the outputs after inference."] @@ -214,11 +227,10 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> { &mut self, exec_context: Resource, name: String, - ) -> wasmtime::Result, Error>> { - use core::result::Result::*; + ) -> Result, Error> { let exec_context = self.table.get_mut(&exec_context)?; tracing::debug!("get output {name:?}"); - let result = match exec_context.get_output(Id::Name(name)) { + match exec_context.get_output(Id::Name(name)) { Ok(tensor) => { let tensor = self.table.push(tensor)?; Ok(tensor) @@ -227,8 +239,7 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> { tracing::error!("failed to get output: {error:?}"); Err(Error::RuntimeError) } - }; - wasmtime::Result::Ok(result) + } } fn drop(&mut self, exec_context: Resource) -> wasmtime::Result<()> { @@ -275,7 +286,20 @@ impl gen::tensor::HostTensor for WasiNnView<'_> { } impl gen::tensor::Host for WasiNnView<'_> {} -impl gen::errors::Host for WasiNnView<'_> {} +impl gen::errors::Host for WasiNnView<'_> { + fn convert_error(&mut self, err: Error) -> wasmtime::Result { + match err { + Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument), + Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding), + Error::Timeout => Ok(gen::errors::Error::Timeout), + Error::RuntimeError => Ok(gen::errors::Error::RuntimeError), + Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation), + Error::TooLarge => Ok(gen::errors::Error::TooLarge), + Error::NotFound => Ok(gen::errors::Error::NotFound), + Error::Trap(e) => Err(e), + } + } +} impl gen::inference::Host for WasiNnView<'_> {} impl Hash for gen::graph::GraphEncoding {