diff --git a/crates/test-programs/Cargo.toml b/crates/test-programs/Cargo.toml index a17b5631af8f..86854364ff43 100644 --- a/crates/test-programs/Cargo.toml +++ b/crates/test-programs/Cargo.toml @@ -12,6 +12,7 @@ workspace = true [dependencies] anyhow = { workspace = true } wasi = "0.11.0" +wasi-nn = "0.6.0" wit-bindgen = { workspace = true, features = ['default'] } libc = { workspace = true } getrandom = "0.2.9" diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index b97422d3b953..193c2b3134dd 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -65,14 +65,6 @@ fn build_and_generate_tests() { generated_code += &format!("pub const {camel}: &'static str = {wasm:?};\n"); - let adapter = match target.as_str() { - "reactor" => &reactor_adapter, - s if s.starts_with("api_proxy") => &proxy_adapter, - _ => &command_adapter, - }; - let path = compile_component(&wasm, adapter); - generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n"); - // Bucket, based on the name of the test, into a "kind" which generates // a `foreach_*` macro below. let kind = match target.as_str() { @@ -81,6 +73,7 @@ fn build_and_generate_tests() { s if s.starts_with("preview2_") => "preview2", s if s.starts_with("cli_") => "cli", s if s.starts_with("api_") => "api", + s if s.starts_with("nn_") => "nn", // If you're reading this because you hit this panic, either add it // to a test suite above or add a new "suite". The purpose of the // categorization above is to have a static assertion that tests @@ -93,6 +86,18 @@ fn build_and_generate_tests() { if !kind.is_empty() { kinds.entry(kind).or_insert(Vec::new()).push(target); } + + // Generate a component from each test. 
+ if kind == "nn" { + continue; + } + let adapter = match target.as_str() { + "reactor" => &reactor_adapter, + s if s.starts_with("api_proxy") => &proxy_adapter, + _ => &command_adapter, + }; + let path = compile_component(&wasm, adapter); + generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n"); } for (kind, targets) in kinds { diff --git a/crates/test-programs/src/bin/nn_image_classification.rs b/crates/test-programs/src/bin/nn_image_classification.rs new file mode 100644 index 000000000000..f81b89154ed1 --- /dev/null +++ b/crates/test-programs/src/bin/nn_image_classification.rs @@ -0,0 +1,59 @@ +use anyhow::Result; +use std::fs; +use wasi_nn::*; + +pub fn main() -> Result<()> { + let xml = fs::read_to_string("fixture/model.xml").unwrap(); + println!("Read graph XML, first 50 characters: {}", &xml[..50]); + + let weights = fs::read("fixture/model.bin").unwrap(); + println!("Read graph weights, size in bytes: {}", weights.len()); + + let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) + .build_from_bytes([&xml.into_bytes(), &weights])?; + println!("Loaded graph into wasi-nn with ID: {}", graph); + + let mut context = graph.init_execution_context()?; + println!("Created wasi-nn execution context with ID: {}", context); + + // Load a tensor that precisely matches the graph input tensor (see + // `fixture/frozen_inference_graph.xml`). + let data = fs::read("fixture/tensor.bgr").unwrap(); + println!("Read input tensor, size in bytes: {}", data.len()); + context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?; + + // Execute the inference. + context.compute()?; + println!("Executed graph inference"); + + // Retrieve the output. + let mut output_buffer = vec![0f32; 1001]; + context.get_output(0, &mut output_buffer[..])?; + println!( + "Found results, sorted top 5: {:?}", + &sort_results(&output_buffer)[..5] + ); + + Ok(()) +} + +// Sort the buffer of probabilities. 
The graph places the match probability for +// each class at the index for that class (e.g. the probability of class 42 is +// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort +// the results. It is unclear why the MobileNet output indices are "off by one" +// but the `.skip(1)` below seems necessary to get results that make sense (e.g. +// 763 = "revolver" vs 762 = "restaurant"). +fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> { + let mut results: Vec<InferenceResult> = buffer + .iter() + .skip(1) + .enumerate() + .map(|(c, p)| InferenceResult(c, *p)) + .collect(); + results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + results +} + +// A wrapper for class ID and match probabilities. +#[derive(Debug, PartialEq)] +struct InferenceResult(usize, f32); diff --git a/crates/test-programs/src/bin/nn_image_classification_named.rs b/crates/test-programs/src/bin/nn_image_classification_named.rs new file mode 100644 index 000000000000..9e70770efc10 --- /dev/null +++ b/crates/test-programs/src/bin/nn_image_classification_named.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use std::fs; +use wasi_nn::*; + +pub fn main() -> Result<()> { + let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) + .build_from_cache("mobilenet")?; + println!("Loaded a graph: {:?}", graph); + + let mut context = graph.init_execution_context()?; + println!("Created an execution context: {:?}", context); + + // Load a tensor that precisely matches the graph input tensor (see + // `fixture/frozen_inference_graph.xml`). + let tensor_data = fs::read("fixture/tensor.bgr")?; + println!("Read input tensor, size in bytes: {}", tensor_data.len()); + context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)?; + + // Execute the inference. + context.compute()?; + println!("Executed graph inference"); + + // Retrieve the output. 
+ let mut output_buffer = vec![0f32; 1001]; + context.get_output(0, &mut output_buffer[..])?; + + println!( + "Found results, sorted top 5: {:?}", + &sort_results(&output_buffer)[..5] + ); + Ok(()) +} + +// Sort the buffer of probabilities. The graph places the match probability for +// each class at the index for that class (e.g. the probability of class 42 is +// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort +// the results. It is unclear why the MobileNet output indices are "off by one" +// but the `.skip(1)` below seems necessary to get results that make sense (e.g. +// 763 = "revolver" vs 762 = "restaurant"). +fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> { + let mut results: Vec<InferenceResult> = buffer + .iter() + .skip(1) + .enumerate() + .map(|(c, p)| InferenceResult(c, *p)) + .collect(); + results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + results +} + +// A wrapper for class ID and match probabilities. +#[derive(Debug, PartialEq)] +struct InferenceResult(usize, f32);