[WASI-NN] Add support for an ONNX Runtime backend using ort (#7691)
* add an initial implementation of an ONNX Runtime backend for wasi-nn

Signed-off-by: David Justice <david@devigned.com>

* vet: audit ONNX dependencies

This change is the result of a long slog through the dependencies of the
`ort` library. The only dependency still missing an audit is `compact_str`,
which needs further discussion.

* vet: add ONNX audit entry for compact_str 0.7.1

Signed-off-by: David Justice <david@devigned.com>

* refactor tests to break out onnx and openvino

Signed-off-by: David Justice <david@devigned.com>

* mark the wasi-nn onnx example as `publish = false`

Signed-off-by: David Justice <david@devigned.com>

* update the ONNX classification example

* do not use the wasi-nn onnx feature on riscv or s390x

Signed-off-by: David Justice <david@devigned.com>

* prtest:full fix running WASI-NN ONNX tests across architectures and OSes

Signed-off-by: David Justice <david@devigned.com>

---------

Signed-off-by: David Justice <david@devigned.com>
Co-authored-by: Andrew Brown <andrew.brown@intel.com>
devigned and abrown authored Mar 13, 2024
1 parent daa7fdf commit d6945bc
Showing 23 changed files with 11,088 additions and 860 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -559,7 +559,7 @@ jobs:
fi
# Build and test all features
-      - run: ./ci/run-tests.sh --locked
+      - run: ./ci/run-tests.sh ${{ matrix.extra_features }} --locked
env:
RUST_BACKTRACE: 1

104 changes: 102 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

14 changes: 9 additions & 5 deletions ci/build-test-matrix.js
@@ -39,7 +39,8 @@ const array = [
"os": "ubuntu-latest",
"name": "Test Linux x86_64",
"filter": "linux-x64",
"isa": "x64"
"isa": "x64",
"extra_features": "--features wasmtime-wasi-nn/onnx"
},
{
"os": "ubuntu-latest",
@@ -57,18 +58,21 @@ const array = [
{
"os": "macos-latest",
"name": "Test macOS x86_64",
"filter": "macos-x64"
"filter": "macos-x64",
"extra_features": "--features wasmtime-wasi-nn/onnx"
},
{
"os": "macos-14",
"name": "Test macOS arm64",
"filter": "macos-arm64",
"target": "aarch64-apple-darwin"
"target": "aarch64-apple-darwin",
"extra_features": "--features wasmtime-wasi-nn/onnx"
},
{
"os": "windows-latest",
"name": "Test Windows MSVC x86_64",
"filter": "windows-x64"
"filter": "windows-x64",
"extra_features": "--features wasmtime-wasi-nn/onnx"
},
{
"os": "windows-latest",
@@ -85,7 +89,7 @@ const array = [
"qemu_target": "aarch64-linux-user",
"name": "Test Linux arm64",
"filter": "linux-arm64",
"isa": "aarch64"
"isa": "aarch64",
},
{
"os": "ubuntu-latest",
@@ -3,8 +3,9 @@ use std::fs;
 use wasi_nn::*;

 pub fn main() -> Result<()> {
+    // Load model from preloaded directory named "fixtures" which contains a model.[bin|xml] mobilenet model.
     let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
-        .build_from_cache("mobilenet")?;
+        .build_from_cache("fixtures")?;
     println!("Loaded a graph: {:?}", graph);

     let mut context = graph.init_execution_context()?;
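For context, `build_from_cache` loads a model by the name of a directory the host has preloaded (here the mapped `fixtures` directory holding the OpenVINO `model.bin`/`model.xml` pair), while the new ONNX example below passes raw model bytes via `build_from_bytes`. A minimal sketch of the two loading paths, using only the guest-side `wasi_nn` API already exercised by these test programs (paths and function names are illustrative):

use anyhow::Result;
use wasi_nn::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};

// Named (cached) loading: the guest refers to a model directory by name and
// the host decides what "fixtures" maps to.
fn load_openvino_named() -> Result<Graph> {
    Ok(GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
        .build_from_cache("fixtures")?)
}

// Bytes loading: the guest reads the model file itself and hands the bytes in.
fn load_onnx_from_bytes() -> Result<Graph> {
    let model = std::fs::read("fixture/model.onnx")?;
    Ok(GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU)
        .build_from_bytes([&model])?)
}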
55 changes: 55 additions & 0 deletions crates/test-programs/src/bin/nn_image_classification_onnx.rs
@@ -0,0 +1,55 @@
use anyhow::Result;
use std::fs;
use wasi_nn::*;

pub fn main() -> Result<()> {
    let model = fs::read("fixture/model.onnx").unwrap();
    println!("[ONNX] Read model, size in bytes: {}", model.len());

    let graph =
        GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?;

    let mut context = graph.init_execution_context()?;
    println!(
        "[ONNX] Created wasi-nn execution context with ID: {}",
        context
    );

    // Prepare WASI-NN tensor - Tensor data is always a bytes vector
    // Load a tensor that precisely matches the graph input tensor
    let data = fs::read("fixture/tensor.bgr").unwrap();
    println!("[ONNX] Read input tensor, size in bytes: {}", data.len());
    context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?;

    // Execute the inferencing
    context.compute()?;
    println!("[ONNX] Executed graph inference");

    // Retrieve the output.
    let mut output_buffer = vec![0f32; 1000];
    context.get_output(0, &mut output_buffer[..])?;
    println!(
        "[ONNX] Found results, sorted top 5: {:?}",
        &sort_results(&output_buffer)[..5]
    );

    Ok(())
}

// Sort the buffer of probabilities. The graph places the match probability for
// each class at the index for that class (e.g. the probability of class 42 is
// placed at buffer[42]). Here we convert to a wrapping InferenceResult and sort
// the results.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
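As an aside (illustrative, not part of the commit), the hard-coded sizes in the test program are consistent: an f32 tensor of shape [1, 3, 224, 224] occupies 1 × 3 × 224 × 224 × 4 = 602,112 bytes, which is what `fixture/tensor.bgr` must contain, and the 1,000-element output buffer suggests an ImageNet-style 1,000-class classifier. A small check of that arithmetic:

// Sanity-check the tensor sizes assumed by the test program above.
fn main() {
    let shape = [1usize, 3, 224, 224];
    let f32_bytes = std::mem::size_of::<f32>(); // 4 bytes per element
    let input_bytes: usize = shape.iter().product::<usize>() * f32_bytes;
    assert_eq!(input_bytes, 602_112); // expected size of fixture/tensor.bgr

    let output_bytes = 1_000 * f32_bytes; // 1,000 f32 class scores
    assert_eq!(output_bytes, 4_000);
    println!("input: {input_bytes} bytes, output: {output_bytes} bytes");
}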
4 changes: 4 additions & 0 deletions crates/wasi-nn/Cargo.toml
@@ -29,6 +29,8 @@ openvino = { version = "0.6.0", features = [
"runtime-linking",
], optional = true }

+ort = { version = "2.0.0-rc.0", default-features = false, features = ["copy-dylibs", "download-binaries"], optional = true }
+
[target.'cfg(windows)'.dependencies.windows]
version = "0.52"
features = [
@@ -51,5 +53,7 @@ wasmtime = { workspace = true, features = ["cranelift"] }
default = ["openvino"]
# openvino is available on all platforms, it requires openvino installed.
openvino = ["dep:openvino"]
+# onnx is available on all platforms.
+onnx = ["dep:ort"]
# winml is only available on Windows 10 1809 and later.
winml = ["dep:windows"]
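As a usage note, `onnx` is an opt-in Cargo feature of `wasmtime-wasi-nn`, so a host built without it only surfaces the problem when the guest tries to load a graph with `GraphEncoding::Onnx`. A hedged guest-side sketch that turns that load failure into a clearer error (reusing only the `wasi_nn` API from the test programs above; the message text is illustrative):

use anyhow::{Context, Result};
use wasi_nn::{ExecutionTarget, Graph, GraphBuilder, GraphEncoding};

// Attempt an ONNX load and attach a hint about the likely cause of failure:
// the host may not have been built with `--features wasmtime-wasi-nn/onnx`.
fn load_onnx(model: &[u8]) -> Result<Graph> {
    GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU)
        .build_from_bytes([model])
        .context("failed to load ONNX graph; was the host built with the wasmtime-wasi-nn `onnx` feature?")
}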