Skip to content

Commit

Permalink
wasi-nn: add test programs
Browse files Browse the repository at this point in the history
This change adds new test programs for wasi-nn in a way that fits in with the
existing WASI test infrastructure. The code is not new, though: this
reuses the wasi-nn `examples`, which are currently used by the
`run-wasi-nn-example.sh` CI script. Eventually the examples will be
removed in favor of these tests.

Because wasi-nn's component model support is still in flight, this
change also skips the generation of components for `nn_`-prefixed tests.
  • Loading branch information
abrown committed Dec 13, 2023
1 parent a18752c commit d923123
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 8 deletions.
1 change: 1 addition & 0 deletions crates/test-programs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ workspace = true
[dependencies]
anyhow = { workspace = true }
wasi = "0.11.0"
wasi-nn = "0.6.0"
wit-bindgen = { workspace = true, features = ['default'] }
libc = { workspace = true }
getrandom = "0.2.9"
Expand Down
21 changes: 13 additions & 8 deletions crates/test-programs/artifacts/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,6 @@ fn build_and_generate_tests() {

generated_code += &format!("pub const {camel}: &'static str = {wasm:?};\n");

let adapter = match target.as_str() {
"reactor" => &reactor_adapter,
s if s.starts_with("api_proxy") => &proxy_adapter,
_ => &command_adapter,
};
let path = compile_component(&wasm, adapter);
generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n");

// Bucket, based on the name of the test, into a "kind" which generates
// a `foreach_*` macro below.
let kind = match target.as_str() {
Expand All @@ -81,6 +73,7 @@ fn build_and_generate_tests() {
s if s.starts_with("preview2_") => "preview2",
s if s.starts_with("cli_") => "cli",
s if s.starts_with("api_") => "api",
s if s.starts_with("nn_") => "nn",
// If you're reading this because you hit this panic, either add it
// to a test suite above or add a new "suite". The purpose of the
// categorization above is to have a static assertion that tests
Expand All @@ -93,6 +86,18 @@ fn build_and_generate_tests() {
if !kind.is_empty() {
kinds.entry(kind).or_insert(Vec::new()).push(target);
}

// Generate a component from each test.
if kind == "nn" {
continue;
}
let adapter = match target.as_str() {
"reactor" => &reactor_adapter,
s if s.starts_with("api_proxy") => &proxy_adapter,
_ => &command_adapter,
};
let path = compile_component(&wasm, adapter);
generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n");
}

for (kind, targets) in kinds {
Expand Down
59 changes: 59 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use anyhow::Result;
use std::fs;
use wasi_nn::*;

/// Run an OpenVINO image-classification inference through wasi-nn.
///
/// Loads a model (XML topology + binary weights) and an input tensor from the
/// `fixture` directory, executes one inference, and prints the top-5 results.
/// Fails with an `Err` if any fixture file is missing or any wasi-nn call
/// errors out (previously these paths `unwrap`ped and panicked instead).
pub fn main() -> Result<()> {
    let xml = fs::read_to_string("fixture/model.xml")?;
    // Take at most 50 *characters* so short or multi-byte XML cannot panic
    // the way a raw byte slice `&xml[..50]` could.
    let preview: String = xml.chars().take(50).collect();
    println!("Read graph XML, first 50 characters: {}", preview);

    let weights = fs::read("fixture/model.bin")?;
    println!("Read graph weights, size in bytes: {}", weights.len());

    let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_bytes([&xml.into_bytes(), &weights])?;
    println!("Loaded graph into wasi-nn with ID: {}", graph);

    let mut context = graph.init_execution_context()?;
    println!("Created wasi-nn execution context with ID: {}", context);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`): 1x3x224x224 f32, BGR order.
    let data = fs::read("fixture/tensor.bgr")?;
    println!("Read input tensor, size in bytes: {}", data.len());
    context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?;

    // Execute the inference.
    context.compute()?;
    println!("Executed graph inference");

    // Retrieve the output. MobileNet emits 1001 class probabilities.
    let mut output_buffer = vec![0f32; 1001];
    context.get_output(0, &mut output_buffer[..])?;
    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    );

    Ok(())
}

// Convert a raw probability buffer into `InferenceResult`s sorted from most to
// least likely. The graph stores the probability for class `c` at `buffer[c]`,
// but the MobileNet output appears shifted by one (e.g. 763 = "revolver" vs
// 762 = "restaurant"), hence the leading element is dropped before pairing
// each remaining probability with its class index.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut ranked: Vec<InferenceResult> = buffer
        .iter()
        .copied()
        .skip(1)
        .enumerate()
        .map(|(class, prob)| InferenceResult(class, prob))
        .collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap());
    ranked
}

// Pairs a class ID with its match probability.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
53 changes: 53 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_named.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use anyhow::Result;
use std::fs;
use wasi_nn::*;

/// Classify an image using a *named* model: the graph is looked up in the
/// runtime's model cache under the name "mobilenet" instead of being loaded
/// from raw bytes.
pub fn main() -> Result<()> {
    // Retrieve the preloaded graph from the wasi-nn cache by name.
    let model = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_cache("mobilenet")?;
    println!("Loaded a graph: {:?}", model);

    let mut exec_ctx = model.init_execution_context()?;
    println!("Created an execution context: {:?}", exec_ctx);

    // Load a tensor that precisely matches the graph input tensor (see
    // `fixture/frozen_inference_graph.xml`).
    let input_bytes = fs::read("fixture/tensor.bgr")?;
    println!("Read input tensor, size in bytes: {}", input_bytes.len());
    exec_ctx.set_input(0, TensorType::F32, &[1, 3, 224, 224], &input_bytes)?;

    // Execute the inference.
    exec_ctx.compute()?;
    println!("Executed graph inference");

    // Retrieve the output probabilities (1001 classes) and report the top 5.
    let mut scores = vec![0f32; 1001];
    exec_ctx.get_output(0, &mut scores[..])?;

    println!(
        "Found results, sorted top 5: {:?}",
        &sort_results(&scores)[..5]
    );
    Ok(())
}

// Rank class probabilities from highest to lowest. The graph places the match
// probability for class `c` at `buffer[c]`; the MobileNet indices are "off by
// one" for reasons that are unclear, so the first element is skipped (e.g.
// 763 = "revolver" vs 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results = Vec::with_capacity(buffer.len().saturating_sub(1));
    for (class, prob) in buffer.iter().skip(1).enumerate() {
        results.push(InferenceResult(class, *prob));
    }
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper tying a class ID to its match probability.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);

0 comments on commit d923123

Please sign in to comment.