Skip to content

Commit

Permalink
fix microtvm test script
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Aug 11, 2021
1 parent f6bb1e1 commit a6dad01
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 18 deletions.
12 changes: 3 additions & 9 deletions tests/micro/zephyr/test_zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def test_relay(temp_dir, platform, west_cmd, tvm_debug):
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = tvm.relay.build(func, target=target)

with _make_session(
temp_dir, zephyr_board, west_cmd, mod, build_config
) as session:
with _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) as session:
graph_mod = tvm.micro.create_local_graph_executor(
mod.get_graph_json(), session.get_system_lib(), session.device
)
Expand Down Expand Up @@ -255,9 +253,7 @@ def test_onnx(temp_dir, platform, west_cmd, tvm_debug):
lowered = relay.build(relay_mod, target, params=params)
graph = lowered.get_graph_json()

with _make_session(
temp_dir, zephyr_board, west_cmd, lowered, build_config
) as session:
with _make_session(temp_dir, zephyr_board, west_cmd, lowered, build_config) as session:
graph_mod = tvm.micro.create_local_graph_executor(
graph, session.get_system_lib(), session.device
)
Expand All @@ -284,9 +280,7 @@ def check_result(
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = tvm.relay.build(relay_mod, target=target)

with _make_session(
temp_dir, zephyr_board, west_cmd, mod, build_config
) as session:
with _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config) as session:
rt_mod = tvm.micro.create_local_graph_executor(
mod.get_graph_json(), session.get_system_lib(), session.device
)
Expand Down
14 changes: 6 additions & 8 deletions tests/micro/zephyr/test_zephyr_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@
PLATFORMS = conftest.PLATFORMS


def _build_project(
temp_dir, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None
):
def _build_project(temp_dir, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None):
template_project_dir = (
pathlib.Path(__file__).parent
/ ".."
Expand All @@ -69,7 +67,7 @@ def _build_project(
"extra_files_tar": extra_files_tar,
"project_type": "aot_demo",
"west_cmd": west_cmd,
"verbose": 0,
"verbose": bool(build_config.get("debug")),
"zephyr_board": zephyr_board,
},
)
Expand Down Expand Up @@ -139,7 +137,7 @@ def _get_message(fd, expr: str):


@tvm.testing.requires_micro
def test_tflite(temp_dir, platform, west_cmd, skip_build, tvm_debug):
def test_tflite(temp_dir, platform, west_cmd, tvm_debug):
"""Testing a TFLite model."""

if platform not in ["host", "mps2_an521", "nrf5340dk", "stm32l4r5zi_nucleo", "zynq_mp_r5"]:
Expand All @@ -148,7 +146,7 @@ def test_tflite(temp_dir, platform, west_cmd, skip_build, tvm_debug):
model, zephyr_board = PLATFORMS[platform]
input_shape = (1, 32, 32, 3)
output_shape = (1, 10)
build_config = {"skip_build": skip_build, "debug": tvm_debug}
build_config = {"debug": tvm_debug}

model_url = "https://github.com/eembc/ulpmark-ml/raw/fc1499c7cc83681a02820d5ddf5d97fe75d4f663/base_models/ic01/ic01_fp32.tflite"
model_path = download_testdata(model_url, "ic01_fp32.tflite", module="model")
Expand Down Expand Up @@ -222,13 +220,13 @@ def test_tflite(temp_dir, platform, west_cmd, skip_build, tvm_debug):


@tvm.testing.requires_micro
def test_qemu_make_fail(temp_dir, platform, west_cmd, skip_build, tvm_debug):
def test_qemu_make_fail(temp_dir, platform, west_cmd, tvm_debug):
"""Testing QEMU make fail."""
if platform not in ["host", "mps2_an521"]:
pytest.skip(msg="Only for QEMU targets.")

model, zephyr_board = PLATFORMS[platform]
build_config = {"skip_build": skip_build, "debug": tvm_debug}
build_config = {"debug": tvm_debug}
shape = (10,)
dtype = "float32"

Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/task_python_microtvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ set -x # NOTE(areusch): Adding to diagnose flaky timeouts
source tests/scripts/setup-pytest-env.sh

make cython3
run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=host
run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=qemu_x86
run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=mps2_an521

0 comments on commit a6dad01

Please sign in to comment.