Skip to content

Commit

Permalink
Use tensor data as input for wasi-nn WinML backend test.
Browse files Browse the repository at this point in the history
To avoid involving too many external dependencies, input image is
converted to tensor data offline.
  • Loading branch information
jianjunz committed Jan 31, 2024
1 parent a9420c4 commit 1262255
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 137 deletions.
110 changes: 0 additions & 110 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions crates/test-programs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,4 @@ getrandom = "0.2.9"
futures = { workspace = true, default-features = false, features = ['alloc'] }
url = { workspace = true }
sha2 = "0.10.2"
base64 = "0.21.0"
image2tensor = "0.3.1"
base64 = "0.21.0"
26 changes: 3 additions & 23 deletions crates/test-programs/src/bin/nn_image_classification_winml.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Result;
use std::fs;
use std::time::Instant;
use wasi_nn::*;

Expand All @@ -13,16 +14,7 @@ pub fn main() -> Result<()> {
println!("Created an execution context.");

// Convert image to tensor data.
let mut tensor_data = image2tensor::convert_image_to_planar_tensor_bytes(
"fixture/kitten.png",
224,
224,
image2tensor::TensorType::F32,
image2tensor::ColorOrder::RGB,
)
.unwrap();
// The model requires values in the range of [0, 1].
scale(&mut tensor_data);
let tensor_data = fs::read("fixture/kitten.rgb")?;
context
.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
.unwrap();
Expand All @@ -40,9 +32,8 @@ pub fn main() -> Result<()> {
context.get_output(0, &mut output_buffer[..]).unwrap();

let result = sort_results(&output_buffer);
assert_eq!(result[0].0, 284);

println!("Found results, sorted top 5: {:?}", &result[..5]);
assert_eq!(result[0].0, 284);
Ok(())
}

Expand All @@ -62,17 +53,6 @@ fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
results
}

// Convert values from [0, 255] to [0, 1].
fn scale(buffer: &mut Vec<u8>) {
const F32_LEN: usize = 4;
for i in (0..buffer.len()).step_by(F32_LEN) {
let mut num = f32::from_ne_bytes(buffer[i..i + F32_LEN].try_into().unwrap());
num /= 225.0;
let num_vec = num.to_ne_bytes().to_vec();
buffer.splice(i..i + F32_LEN, num_vec);
}
}

// A wrapper for class ID and match probabilities.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
9 changes: 7 additions & 2 deletions crates/wasi-nn/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,20 @@ fn check_winml_artifacts_are_available() -> Result<()> {
fs::create_dir(&artifacts_dir)?;
}
const MODEL_URL: &str="https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/mobilenet/model/mobilenetv2-12.onnx";
const IMG_URL: &str="https://github.com/microsoft/Windows-Machine-Learning/blob/master/SharedContent/media/kitten_224.png?raw=true";
for (from, to) in [(MODEL_URL, "mobilenet.onnx"), (IMG_URL, "kitten.png")] {
for (from, to) in [(MODEL_URL, "mobilenet.onnx")] {
let local_path = artifacts_dir.join(to);
if !local_path.is_file() {
download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
} else {
println!("> using cached artifact: {}", local_path.display())
}
}
// kitten.rgb is converted from https://github.com/microsoft/Windows-Machine-Learning/blob/master/SharedContent/media/kitten_224.png?raw=true.
let tensor_path = env::current_dir()?
.join("tests")
.join("fixtures")
.join("kitten.rgb");
fs::copy(tensor_path, artifacts_dir.join("kitten.rgb"))?;
Ok(())
}

Expand Down
Binary file added crates/wasi-nn/tests/fixtures/kitten.rgb
Binary file not shown.

0 comments on commit 1262255

Please sign in to comment.