WinML backend for wasi-nn (#7807)
* Add WinML backend for wasi-nn.

* Log execution time.

* WinML backend supports execution target selection.

ExecutionTarget::Gpu is mapped to LearningModelDeviceKind::DirectX.
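
A minimal sketch of that mapping, for illustration only (the local `ExecutionTarget` enum and the `device_kind_for` helper are hypothetical stand-ins for the real backend code):

```rust
use windows::AI::MachineLearning::LearningModelDeviceKind;

// Hypothetical stand-in for wasi-nn's execution-target type.
enum ExecutionTarget {
    Cpu,
    Gpu,
    Tpu,
}

// Sketch of the mapping described above; the real backend may differ.
fn device_kind_for(target: ExecutionTarget) -> Option<LearningModelDeviceKind> {
    match target {
        ExecutionTarget::Cpu => Some(LearningModelDeviceKind::Cpu),
        // GPU maps to DirectX, per this commit.
        ExecutionTarget::Gpu => Some(LearningModelDeviceKind::DirectX),
        // WinML has no TPU device kind.
        ExecutionTarget::Tpu => None,
    }
}
```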

* Limit the WinML backend to Windows only.

* Move wasi-nn WinML example to new test infra.

* Scale tensor data in test app.

The app knows the input and target data ranges, so it is better to let
the app handle the scaling.
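
For example, a model expecting inputs in [0, 1] needs raw u8 pixels divided by 255; a minimal sketch, assuming that target range:

```rust
// Scale raw u8 pixel bytes into [0.0, 1.0] f32 values. The target range
// here is an assumption; it depends on the model's training preprocessing.
fn scale_to_unit_range(raw: &[u8]) -> Vec<f32> {
    raw.iter().map(|&b| f32::from(b) / 255.0).collect()
}
```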

* Remove old example for wasi-nn WinML backend.

* Update image2tensor link.

* Format code.

* Upgrade image2tensor to 0.3.1.

* Upgrade windows to 0.52.0

* Use tensor data as input for wasi-nn WinML backend test.

To avoid pulling in too many external dependencies, the input image is
converted to tensor data offline.
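
A sketch of how such an offline conversion step might serialize the tensor (the helper name is an assumption; the actual tooling used for the conversion is not shown in this commit):

```rust
use std::fs;

// Write an f32 tensor as raw little-endian bytes so a test can read it
// back directly with `fs::read`, as `nn_image_classification_winml` does.
fn write_tensor_file(path: &str, tensor: &[f32]) -> std::io::Result<()> {
    let bytes: Vec<u8> = tensor.iter().flat_map(|f| f.to_le_bytes()).collect();
    fs::write(path, bytes)
}
```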

* Restore trailing new line for Cargo.toml.

* Remove unnecessary features for windows crate.

* Check input tensor types.

Only FP32 is supported right now. Reject other tensor types.
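
A minimal sketch of such a guard (the enum mirrors a subset of wasi-nn's tensor types; the error type is illustrative):

```rust
// Illustrative subset of wasi-nn's tensor types.
#[derive(Debug)]
enum TensorType {
    Fp16,
    Fp32,
    U8,
    I32,
}

// Only FP32 input tensors are accepted; everything else is rejected.
fn check_tensor_type(ty: &TensorType) -> Result<(), String> {
    match ty {
        TensorType::Fp32 => Ok(()),
        other => Err(format!("unsupported tensor type {other:?}: only FP32 is supported")),
    }
}
```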

* Rename default model name to model.onnx.

This aligns with the OpenVINO backend.

prtest:full

* Run nn_image_classification_winml only when winml is enabled.
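
In general, a feature-gated test looks like the sketch below (the real test is generated by the test-programs infrastructure, so its exact shape differs):

```rust
// Compile (and run) this test only when the `winml` feature is enabled.
#[cfg(feature = "winml")]
#[test]
fn nn_image_classification_winml() {
    // ...invoke the guest program against the WinML backend...
}
```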

* vet: add trusted `windows` crate to lockfile

* Fix wasi-nn tests when both openvino and winml are enabled.

* Add check for WinML availability.

* vet: reapply vet lock

---------

Co-authored-by: Andrew Brown <andrew.brown@intel.com>
jianjunz and abrown authored Feb 29, 2024
1 parent 8eab5f8 commit 10a9f9e
Showing 9 changed files with 459 additions and 39 deletions.
70 changes: 45 additions & 25 deletions Cargo.lock


58 changes: 58 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_winml.rs
@@ -0,0 +1,58 @@
use anyhow::Result;
use std::fs;
use std::time::Instant;
use wasi_nn::*;

pub fn main() -> Result<()> {
    // The graph is expected to be preloaded via the `nn-graph` argument; the
    // preload path ends with "mobilenet".
    let graph =
        wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)
            .build_from_cache("mobilenet")
            .unwrap();

    let mut context = graph.init_execution_context().unwrap();
    println!("Created an execution context.");

    // Load tensor data that was converted offline from the input image
    // (1x3x224x224, f32, per the dimensions passed to `set_input` below).
    let tensor_data = fs::read("fixture/kitten.rgb")?;
    context
        .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
        .unwrap();

    // Execute the inference and log how long it took.
    let before_compute = Instant::now();
    context.compute().unwrap();
    println!(
        "Executed graph inference, took {} ms.",
        before_compute.elapsed().as_millis()
    );

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..]).unwrap();

    let result = sort_results(&output_buffer);
    println!("Found results, sorted top 5: {:?}", &result[..5]);
    assert_eq!(result[0].0, 284);
    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g., the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one", but the `.skip(1)` below seems necessary to get results that make
// sense (e.g., 763 = "revolver" vs. 762 = "restaurant").
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
20 changes: 19 additions & 1 deletion crates/wasi-nn/Cargo.toml
@@ -24,8 +24,19 @@ wasmtime = { workspace = true, features = ["component-model", "runtime"] }
 
 # These dependencies are necessary for the wasi-nn implementation:
 tracing = { workspace = true }
-openvino = { version = "0.6.0", features = ["runtime-linking"] }
 thiserror = { workspace = true }
+openvino = { version = "0.6.0", features = [
+    "runtime-linking",
+], optional = true }
+
+[target.'cfg(windows)'.dependencies.windows]
+version = "0.52"
+features = [
+    "AI_MachineLearning",
+    "Storage_Streams",
+    "Foundation_Collections",
+]
+optional = true
 
 [build-dependencies]
 walkdir = { workspace = true }
@@ -35,3 +46,10 @@ cap-std = { workspace = true }
 test-programs-artifacts = { workspace = true }
 wasi-common = { workspace = true, features = ["sync"] }
 wasmtime = { workspace = true, features = ["cranelift"] }
+
+[features]
+default = ["openvino"]
+# openvino is available on all platforms, but it requires OpenVINO to be installed.
+openvino = ["dep:openvino"]
+# winml is only available on Windows 10, version 1809 and later.
+winml = ["dep:windows"]
17 changes: 16 additions & 1 deletion crates/wasi-nn/src/backend/mod.rs
@@ -2,9 +2,15 @@
 //! this crate. The `Box<dyn ...>` types returned by these interfaces allow
 //! implementations to maintain backend-specific state between calls.
+#[cfg(feature = "openvino")]
 pub mod openvino;
+#[cfg(feature = "winml")]
+pub mod winml;
 
+#[cfg(feature = "openvino")]
 use self::openvino::OpenvinoBackend;
+#[cfg(feature = "winml")]
+use self::winml::WinMLBackend;
 use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
 use crate::{Backend, ExecutionContext, Graph};
 use std::path::Path;
@@ -13,7 +19,16 @@ use wiggle::GuestError;
 
 /// Return a list of all available backend frameworks.
 pub fn list() -> Vec<crate::Backend> {
-    vec![Backend::from(OpenvinoBackend::default())]
+    let mut backends = vec![];
+    #[cfg(feature = "openvino")]
+    {
+        backends.push(Backend::from(OpenvinoBackend::default()));
+    }
+    #[cfg(feature = "winml")]
+    {
+        backends.push(Backend::from(WinMLBackend::default()));
+    }
+    backends
 }
 
 /// A [Backend] contains the necessary state to load [Graph]s.
