Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WinML backend for wasi-nn #7807

Merged
merged 21 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_winml.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use anyhow::Result;
use std::fs;
use std::time::Instant;
use wasi_nn::*;

pub fn main() -> Result<()> {
    // The graph is expected to be preloaded via the `nn-graph` argument; its
    // path ends with "mobilenet", which is the cache key looked up here.
    let graph = wasi_nn::GraphBuilder::new(
        wasi_nn::GraphEncoding::Onnx,
        wasi_nn::ExecutionTarget::CPU,
    )
    .build_from_cache("mobilenet")
    .unwrap();

    let mut context = graph.init_execution_context().unwrap();
    println!("Created an execution context.");

    // Read the raw image bytes and bind them as the input tensor
    // (shape [1, 3, 224, 224] — presumably NCHW; confirm against the model).
    let tensor_data = fs::read("fixture/kitten.rgb")?;
    context
        .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
        .unwrap();

    // Run inference and report how long it took.
    let start = Instant::now();
    context.compute().unwrap();
    println!(
        "Executed graph inference, took {} ms.",
        start.elapsed().as_millis()
    );

    // Copy out the 1000-entry probability vector and verify the top class.
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..]).unwrap();

    let result = sort_results(&output_buffer);
    println!("Found results, sorted top 5: {:?}", &result[..5]);
    assert_eq!(result[0].0, 284);
    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for each class at the
// index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
// to a wrapping InferenceResult and sort the results. It is unclear why the MobileNet output
// indices are "off by one" but the `.skip(1)` below seems necessary to get results that make sense
// (e.g. 763 = "revolver" vs 762 = "restaurant")
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    // Sort descending by probability. `total_cmp` imposes a total order on
    // floats (IEEE 754 totalOrder), so the comparator cannot panic even if the
    // backend ever produces a NaN — unlike `partial_cmp(..).unwrap()`, which
    // panics on the first NaN comparison.
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
20 changes: 19 additions & 1 deletion crates/wasi-nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,19 @@ wasmtime = { workspace = true, features = ["component-model", "runtime"] }

# These dependencies are necessary for the wasi-nn implementation:
tracing = { workspace = true }
openvino = { version = "0.6.0", features = ["runtime-linking"] }
thiserror = { workspace = true }
openvino = { version = "0.6.0", features = [
"runtime-linking",
], optional = true }

[target.'cfg(windows)'.dependencies.windows]
version = "0.52"
features = [
"AI_MachineLearning",
"Storage_Streams",
"Foundation_Collections",
]
optional = true

[build-dependencies]
walkdir = { workspace = true }
Expand All @@ -35,3 +46,10 @@ cap-std = { workspace = true }
test-programs-artifacts = { workspace = true }
wasi-common = { workspace = true, features = ["sync"] }
wasmtime = { workspace = true, features = ["cranelift"] }

[features]
default = ["openvino"]
# The openvino backend is available on all platforms, but it requires OpenVINO to be installed.
openvino = ["dep:openvino"]
# winml is only available on Windows 10 1809 and later.
winml = ["dep:windows"]
17 changes: 16 additions & 1 deletion crates/wasi-nn/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
//! this crate. The `Box<dyn ...>` types returned by these interfaces allow
//! implementations to maintain backend-specific state between calls.

#[cfg(feature = "openvino")]
pub mod openvino;
#[cfg(feature = "winml")]
pub mod winml;

#[cfg(feature = "openvino")]
use self::openvino::OpenvinoBackend;
#[cfg(feature = "winml")]
use self::winml::WinMLBackend;
use crate::wit::types::{ExecutionTarget, GraphEncoding, Tensor};
use crate::{Backend, ExecutionContext, Graph};
use std::path::Path;
Expand All @@ -13,7 +19,16 @@ use wiggle::GuestError;

/// Return a list of all available backend frameworks.
pub fn list() -> Vec<crate::Backend> {
vec![Backend::from(OpenvinoBackend::default())]
let mut backends = vec![];
#[cfg(feature = "openvino")]
{
backends.push(Backend::from(OpenvinoBackend::default()));
}
#[cfg(feature = "winml")]
{
backends.push(Backend::from(WinMLBackend::default()));
}
backends
}

/// A [Backend] contains the necessary state to load [Graph]s.
Expand Down
Loading
Loading