Commit
* Add WinML backend for wasi-nn.
* Log execution time.
* WinML backend supports execution target selection. ExecutionTarget::Gpu is mapped to LearningModelDeviceKind::DirectX.
* Limit WinML backend to Windows only.
* Move wasi-nn WinML example to the new test infra.
* Scale tensor data in the test app. The app knows the input and target data ranges, so it's better to let the app handle scaling.
* Remove the old example for the wasi-nn WinML backend.
* Update image2tensor link.
* Format code.
* Upgrade image2tensor to 0.3.1.
* Upgrade windows to 0.52.0.
* Use tensor data as input for the wasi-nn WinML backend test. To avoid involving too many external dependencies, the input image is converted to tensor data offline.
* Restore trailing newline for Cargo.toml.
* Remove unnecessary features for the windows crate.
* Check input tensor types. Only FP32 is supported right now; reject other tensor types.
* Rename the default model to model.onnx. This aligns with the openvino backend. prtest:full
* Run nn_image_classification_winml only when winml is enabled.
* vet: add trusted `windows` crate to lockfile
* Fix wasi-nn tests when both openvino and winml are enabled.
* Add a check for WinML availability.
* vet: reapply vet lock

---------

Co-authored-by: Andrew Brown <andrew.brown@intel.com>
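Two of the points above, the execution-target mapping and the FP32-only input check, are easy to picture with a short sketch. The sketch below is illustrative only: the enum definitions and function names are stand-ins invented for this example, not the backend's actual code (the real backend would use the wasi-nn types and the `windows` crate's Windows.AI.MachineLearning bindings).

// Illustrative sketch, not the actual Wasmtime WinML backend code. Stand-in
// types model the two checks described in the commit message.

#[derive(Debug, Clone, Copy)]
enum ExecutionTarget { Cpu, Gpu, Tpu }

#[derive(Debug, Clone, Copy)]
enum LearningModelDeviceKind { Cpu, DirectX }

#[derive(Debug, Clone, Copy, PartialEq)]
enum TensorType { F16, F32, U8, I32 }

// Map the wasi-nn execution target to a WinML device kind. Per the commit
// message, the GPU target maps to a DirectX device; TPU has no WinML equivalent.
fn device_kind(target: ExecutionTarget) -> Result<LearningModelDeviceKind, String> {
    match target {
        ExecutionTarget::Cpu => Ok(LearningModelDeviceKind::Cpu),
        ExecutionTarget::Gpu => Ok(LearningModelDeviceKind::DirectX),
        ExecutionTarget::Tpu => Err("the WinML backend does not support the TPU target".into()),
    }
}

// Only FP32 input tensors are accepted right now; other tensor types are rejected.
fn check_tensor_type(ty: TensorType) -> Result<(), String> {
    if ty == TensorType::F32 {
        Ok(())
    } else {
        Err(format!("unsupported tensor type {:?}; only F32 is supported", ty))
    }
}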
Showing 9 changed files with 459 additions and 39 deletions.
crates/test-programs/src/bin/nn_image_classification_winml.rs (58 additions, 0 deletions)
@@ -0,0 +1,58 @@
use anyhow::Result;
use std::fs;
use std::time::Instant;
use wasi_nn::*;

pub fn main() -> Result<()> {
    // Graph is supposed to be preloaded by `nn-graph` argument. The path ends with "mobilenet".
    let graph =
        wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)
            .build_from_cache("mobilenet")
            .unwrap();

    let mut context = graph.init_execution_context().unwrap();
    println!("Created an execution context.");

    // Load the tensor data. The input image was converted to raw tensor data offline.
    let tensor_data = fs::read("fixture/kitten.rgb")?;
    context
        .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
        .unwrap();

    // Execute the inference.
    let before_compute = Instant::now();
    context.compute().unwrap();
    println!(
        "Executed graph inference, took {} ms.",
        before_compute.elapsed().as_millis()
    );

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..]).unwrap();

    let result = sort_results(&output_buffer);
    println!("Found results, sorted top 5: {:?}", &result[..5]);
    assert_eq!(result[0].0, 284);
    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
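As the commit message notes, fixture/kitten.rgb is raw tensor data produced from the source image offline, so the test itself carries no image-decoding dependencies. One possible way to generate such a fixture is sketched below; the use of the `image` crate, the NCHW planar layout, the 224x224 size, and the simple 0..1 scaling are assumptions chosen to match the shape the test passes to set_input, not the pipeline actually used for the fixture.

// Hypothetical offline conversion: image file -> raw little-endian f32 tensor
// in [1, 3, 224, 224] (NCHW) layout. The crate choice, layout, and scaling are
// assumptions; the real fixture may have been produced differently.
use std::fs;
use image::imageops::FilterType;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let img = image::open("kitten.jpg")?
        .resize_exact(224, 224, FilterType::Triangle)
        .to_rgb8();

    // Re-order interleaved RGB pixels into planar (channel-major) f32 values.
    let mut planes = vec![0f32; 3 * 224 * 224];
    for (x, y, pixel) in img.enumerate_pixels() {
        for c in 0..3 {
            planes[c * 224 * 224 + (y as usize) * 224 + x as usize] =
                pixel[c] as f32 / 255.0;
        }
    }

    // Write the values as raw little-endian f32 bytes, matching what
    // nn_image_classification_winml.rs reads back with `fs::read`.
    let bytes: Vec<u8> = planes.iter().flat_map(|f| f.to_le_bytes()).collect();
    fs::write("kitten.rgb", bytes)?;
    Ok(())
}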