Skip to content

Commit

Permalink
prtest:full fix running WASI-NN ONNX tests across arch os
Browse files Browse the repository at this point in the history
Signed-off-by: David Justice <david@devigned.com>
  • Loading branch information
devigned committed Mar 13, 2024
1 parent af7c140 commit 2416343
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 29 deletions.
65 changes: 38 additions & 27 deletions crates/wasi-nn/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
//! - that WinML is available
//! - that some ML model artifacts can be downloaded and cached.
#[allow(unused_imports)]
use anyhow::{anyhow, Context, Result};
use std::{env, fs, path::Path, path::PathBuf, process::Command, sync::Mutex};
#[cfg(all(feature = "winml", target_os = "windows"))]

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
use windows::AI::MachineLearning::{LearningModelDevice, LearningModelDeviceKind};

/// Return the directory in which the test artifacts are stored.
Expand Down Expand Up @@ -37,25 +39,39 @@ macro_rules! check_test {

/// Return `Ok` if all checks pass.
pub fn check() -> Result<()> {
#[cfg(feature = "openvino")]
#[cfg(all(
feature = "openvino",
target_arch = "x86_64",
any(target_os = "linux", target_os = "windows")
))]
{
check_openvino_is_installed()?;
check_openvino_artifacts_are_available()?;
}
#[cfg(feature = "openvino")]
check_openvino_artifacts_are_available()?;

#[cfg(feature = "onnx")]
{
check_onnx_artifacts_are_available()?;
}
#[cfg(all(feature = "winml", target_os = "windows"))]
check_onnx_artifacts_are_available()?;

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
{
check_winml_is_available()?;
check_winml_artifacts_are_available()?;
}
Ok(())
}

/// Protect `check_openvino_artifacts_are_available` from concurrent access;
/// when running tests in parallel, we want to avoid two threads attempting to
/// create the same directory or download the same file.
static ARTIFACTS: Mutex<()> = Mutex::new(());

/// Return `Ok` if we find a working OpenVINO installation.
#[cfg(feature = "openvino")]
#[cfg(all(
feature = "openvino",
target_arch = "x86_64",
any(target_os = "linux", target_os = "windows")
))]
fn check_openvino_is_installed() -> Result<()> {
match std::panic::catch_unwind(|| println!("> found openvino version: {}", openvino::version()))
{
Expand All @@ -64,24 +80,6 @@ fn check_openvino_is_installed() -> Result<()> {
}
}

#[cfg(all(feature = "winml", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
match std::panic::catch_unwind(|| {
println!(
"> WinML learning device is available: {:?}",
LearningModelDevice::Create(LearningModelDeviceKind::Default)
)
}) {
Ok(_) => Ok(()),
Err(e) => Err(anyhow!("WinML learning device is not available: {:?}", e)),
}
}

/// Protect `check_openvino_artifacts_are_available` from concurrent access;
/// when running tests in parallel, we want to avoid two threads attempting to
/// create the same directory or download the same file.
static ARTIFACTS: Mutex<()> = Mutex::new(());

/// Return `Ok` if we find the cached MobileNet test artifacts; this will
/// download the artifacts if necessary.
#[cfg(feature = "openvino")]
Expand Down Expand Up @@ -110,6 +108,19 @@ fn check_openvino_artifacts_are_available() -> Result<()> {
Ok(())
}

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
match std::panic::catch_unwind(|| {
println!(
"> WinML learning device is available: {:?}",
LearningModelDevice::Create(LearningModelDeviceKind::Default)
)
}) {
Ok(_) => Ok(()),
Err(e) => Err(anyhow!("WinML learning device is not available: {:?}", e)),
}
}

#[cfg(feature = "onnx")]
fn check_onnx_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
Expand Down Expand Up @@ -141,7 +152,7 @@ fn check_onnx_artifacts_are_available() -> Result<()> {
Ok(())
}

#[cfg(all(feature = "winml", target_os = "windows"))]
#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
fn check_winml_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
let artifacts_dir = artifacts_dir();
Expand Down
7 changes: 5 additions & 2 deletions crates/wasi-nn/tests/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ fn nn_image_classification_winml() {
)]
#[test]
fn nn_image_classification_onnx() {
let backend = Backend::from(backend::onnxruntime::OnnxBackend::default());
run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false).unwrap()
#[cfg(feature = "onnx")]
{
let backend = Backend::from(backend::onnxruntime::OnnxBackend::default());
run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false).unwrap()
}
}

0 comments on commit 2416343

Please sign in to comment.