-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wasi-nn: adapt to new test infrastructure (#7679)
* wasi-nn: add test programs This change adds new test programs for wasi-nn in a way fits in with the existing WASI test infrastructure. The code is not new, though: this reuses the wasi-nn `examples`, which are currently used by the `run-wasi-nn-example.sh` CI script. Eventually the examples will be removed in favor of these tests. Because wasi-nn's component model support is still in flight, this change also skips the generation of components for `nn_`-prefixed tests. * wasi-nn: add `testing` module This testing-only module has code (i.e., `check_test!`) to check whether OpenVINO and some test artifacts are available. The test artifacts are downloaded and cached if not present, expecting `curl` to be present on the command line (as discussed in the previous version of this, #6895). * wasi-nn: run `nn_*` test programs as integration tests Following the pattern of other WASI crates, this change adds the necessary infrastructure to run the `nn_*` files in `crates/test-programs` (built by `test-program-artifacts`). These tests are only run when two sets of conditions are true: - statically: we only run these tests where we expect OpenVINO to be easy to install and run (e.g., the `cfg_attr` parts) - dynamically: we also only run these tests when the OpenVINO libraries can be located and the model artifacts can be downloaded * ci: install OpenVINO for running wasi-nn tests prtest:full * vet: certify the `wasi-nn` crate * ci: remove wasi-nn test script
- Loading branch information
Showing
13 changed files
with
345 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
use anyhow::Result; | ||
use std::fs; | ||
use wasi_nn::*; | ||
|
||
pub fn main() -> Result<()> { | ||
let xml = fs::read_to_string("fixture/model.xml").unwrap(); | ||
println!("Read graph XML, first 50 characters: {}", &xml[..50]); | ||
|
||
let weights = fs::read("fixture/model.bin").unwrap(); | ||
println!("Read graph weights, size in bytes: {}", weights.len()); | ||
|
||
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) | ||
.build_from_bytes([&xml.into_bytes(), &weights])?; | ||
println!("Loaded graph into wasi-nn with ID: {}", graph); | ||
|
||
let mut context = graph.init_execution_context()?; | ||
println!("Created wasi-nn execution context with ID: {}", context); | ||
|
||
// Load a tensor that precisely matches the graph input tensor (see | ||
// `fixture/frozen_inference_graph.xml`). | ||
let data = fs::read("fixture/tensor.bgr").unwrap(); | ||
println!("Read input tensor, size in bytes: {}", data.len()); | ||
context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?; | ||
|
||
// Execute the inference. | ||
context.compute()?; | ||
println!("Executed graph inference"); | ||
|
||
// Retrieve the output. | ||
let mut output_buffer = vec![0f32; 1001]; | ||
context.get_output(0, &mut output_buffer[..])?; | ||
println!( | ||
"Found results, sorted top 5: {:?}", | ||
&sort_results(&output_buffer)[..5] | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
// Sort the buffer of probabilities. The graph places the match probability for | ||
// each class at the index for that class (e.g. the probability of class 42 is | ||
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort | ||
// the results. It is unclear why the MobileNet output indices are "off by one" | ||
// but the `.skip(1)` below seems necessary to get results that make sense (e.g. | ||
// 763 = "revolver" vs 762 = "restaurant"). | ||
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> { | ||
let mut results: Vec<InferenceResult> = buffer | ||
.iter() | ||
.skip(1) | ||
.enumerate() | ||
.map(|(c, p)| InferenceResult(c, *p)) | ||
.collect(); | ||
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); | ||
results | ||
} | ||
|
||
// A wrapper for class ID and match probabilities. | ||
#[derive(Debug, PartialEq)] | ||
struct InferenceResult(usize, f32); |
53 changes: 53 additions & 0 deletions
53
crates/test-programs/src/bin/nn_image_classification_named.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
use anyhow::Result; | ||
use std::fs; | ||
use wasi_nn::*; | ||
|
||
pub fn main() -> Result<()> { | ||
let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU) | ||
.build_from_cache("mobilenet")?; | ||
println!("Loaded a graph: {:?}", graph); | ||
|
||
let mut context = graph.init_execution_context()?; | ||
println!("Created an execution context: {:?}", context); | ||
|
||
// Load a tensor that precisely matches the graph input tensor (see | ||
// `fixture/frozen_inference_graph.xml`). | ||
let tensor_data = fs::read("fixture/tensor.bgr")?; | ||
println!("Read input tensor, size in bytes: {}", tensor_data.len()); | ||
context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)?; | ||
|
||
// Execute the inference. | ||
context.compute()?; | ||
println!("Executed graph inference"); | ||
|
||
// Retrieve the output. | ||
let mut output_buffer = vec![0f32; 1001]; | ||
context.get_output(0, &mut output_buffer[..])?; | ||
|
||
println!( | ||
"Found results, sorted top 5: {:?}", | ||
&sort_results(&output_buffer)[..5] | ||
); | ||
Ok(()) | ||
} | ||
|
||
// Sort the buffer of probabilities. The graph places the match probability for | ||
// each class at the index for that class (e.g. the probability of class 42 is | ||
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort | ||
// the results. It is unclear why the MobileNet output indices are "off by one" | ||
// but the `.skip(1)` below seems necessary to get results that make sense (e.g. | ||
// 763 = "revolver" vs 762 = "restaurant"). | ||
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> { | ||
let mut results: Vec<InferenceResult> = buffer | ||
.iter() | ||
.skip(1) | ||
.enumerate() | ||
.map(|(c, p)| InferenceResult(c, *p)) | ||
.collect(); | ||
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); | ||
results | ||
} | ||
|
||
// A wrapper for class ID and match probabilities. | ||
#[derive(Debug, PartialEq)] | ||
struct InferenceResult(usize, f32); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.