Skip to content

Commit

Permalink
Break error detection tests into separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
guberti committed Aug 30, 2021
1 parent 3a36816 commit 3a8ea4c
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 82 deletions.
40 changes: 40 additions & 0 deletions tests/micro/arduino/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import pytest
import tvm.target.target
from tvm import micro, relay

# The models that should pass this configuration. Maps a short, identifying platform string to
# (model, zephyr_board).
Expand Down Expand Up @@ -121,3 +122,42 @@ def make_workspace_dir(test_name, platform):
t = tvm.contrib.utils.tempdir(board_workspace)
# time.sleep(200)
return t


def make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir):
this_dir = pathlib.Path(__file__).parent
model, arduino_board = PLATFORMS[platform]
build_config = {"debug": tvm_debug}

with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f:
tflite_model_buf = f.read()

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite

tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

mod, params = relay.frontend.from_tflite(tflite_model)
target = tvm.target.target.micro(
model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"]
)

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = relay.build(mod, target, params=params)

return tvm.micro.generate_project(
str(TEMPLATE_PROJECT_DIR),
mod,
workspace_dir / "project",
{
"arduino_board": arduino_board,
"arduino_cli_cmd": arduino_cli_cmd,
"project_type": "example_project",
"verbose": bool(build_config.get("debug")),
},
)
51 changes: 51 additions & 0 deletions tests/micro/arduino/test_arduino_error_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pathlib
import re
import sys

import pytest

import conftest
from tvm.micro.project_api.server import ServerError


@pytest.fixture(scope="function")
def workspace_dir(request, platform):
return conftest.make_workspace_dir("arduino_error_detection", platform)


@pytest.fixture(scope="function")
def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir):
return conftest.make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir)


def test_blank_project_compiles(workspace_dir, project):
project.build()


# Add a bug (an extra curly brace) and make sure the project doesn't compile
def test_bugged_project_compile_fails(workspace_dir, project):
with open(workspace_dir / "project" / "project.ino", "a") as main_file:
main_file.write("}\n")
with pytest.raises(ServerError):
project.build()


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
83 changes: 1 addition & 82 deletions tests/micro/arduino/test_arduino_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@
import pathlib
import shutil
import sys
import tempfile

import pytest

import conftest
import tvm
from tvm import micro, relay
from tvm.micro.project_api.server import ServerError

"""
This unit test simulates a simple user workflow, where we:
Expand All @@ -51,73 +47,10 @@ def project_dir(workspace_dir):
return workspace_dir / "project"


# Saves the Arduino project's state, runs the test, then resets it
@pytest.fixture(scope="function")
def does_not_affect_state(project_dir):
with tempfile.TemporaryDirectory() as temp_dir:
prev_project_state = pathlib.Path(temp_dir)
shutil.copytree(project_dir, prev_project_state / "project")
yield

# We can't delete project_dir or it'll mess up the Arduino CLI working directory
# Instead, delete everything in project_dir, and then copy over the files
for item in project_dir.iterdir():
if item.is_dir():
shutil.rmtree(item)
else:
item.unlink() # Delete file
# Once we upgrade to Python 3.7, this can be replaced with
# shutil.copytree(dirs_exist_ok=True)
for item in (prev_project_state / "project").iterdir():
if item.is_dir():
shutil.copytree(item, project_dir / item.name)
else:
shutil.copy2(item, project_dir)


def _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config):
return tvm.micro.generate_project(
str(conftest.TEMPLATE_PROJECT_DIR),
mod,
workspace_dir / "project",
{
"arduino_board": arduino_board,
"arduino_cli_cmd": arduino_cli_cmd,
"project_type": "example_project",
"verbose": bool(build_config.get("debug")),
},
)


# We MUST pass workspace_dir, not project_dir, or the workspace will be dereferenced too soon
@pytest.fixture(scope="module")
def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir):
this_dir = pathlib.Path(__file__).parent
model, arduino_board = conftest.PLATFORMS[platform]
build_config = {"debug": tvm_debug}

with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f:
tflite_model_buf = f.read()

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite

tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

mod, params = relay.frontend.from_tflite(tflite_model)
target = tvm.target.target.micro(
model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"]
)

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = relay.build(mod, target, params=params)

return _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config)
return conftest.make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir)


def _get_directory_elements(directory):
Expand Down Expand Up @@ -161,20 +94,6 @@ def test_import_rerouting(project_dir, project):
assert "include/tvm/runtime/crt/platform.h" in c_backend_api_c


@pytest.mark.usefixtures("does_not_affect_state")
def test_blank_project_compiles(project):
project.build()


# Add a bug (an extra curly brace) and make sure the project doesn't compile
@pytest.mark.usefixtures("does_not_affect_state")
def test_bugged_project_compile_fails(project_dir, project):
with open(project_dir / "project.ino", "a") as main_file:
main_file.write("}\n")
with pytest.raises(ServerError):
project.build()


# Build on top of the generated project by replacing the
# top-level .ino fileand adding data input files, much
# like a user would
Expand Down

0 comments on commit 3a8ea4c

Please sign in to comment.