Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
Break error detection tests into separate file

Address comments from Mousius

Re-add necessary fixture
  • Loading branch information
guberti committed Sep 2, 2021
1 parent 3a36816 commit e873436
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 85 deletions.
40 changes: 40 additions & 0 deletions tests/micro/arduino/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import pytest
import tvm.target.target
from tvm import micro, relay

# The models that should pass this configuration. Maps a short, identifying platform string to
# (model, zephyr_board).
Expand Down Expand Up @@ -121,3 +122,42 @@ def make_workspace_dir(test_name, platform):
t = tvm.contrib.utils.tempdir(board_workspace)
# time.sleep(200)
return t


def make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir):
this_dir = pathlib.Path(__file__).parent
model, arduino_board = PLATFORMS[platform]
build_config = {"debug": tvm_debug}

with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f:
tflite_model_buf = f.read()

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite

tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

mod, params = relay.frontend.from_tflite(tflite_model)
target = tvm.target.target.micro(
model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"]
)

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = relay.build(mod, target, params=params)

return tvm.micro.generate_project(
str(TEMPLATE_PROJECT_DIR),
mod,
workspace_dir / "project",
{
"arduino_board": arduino_board,
"arduino_cli_cmd": arduino_cli_cmd,
"project_type": "example_project",
"verbose": bool(build_config.get("debug")),
},
)
52 changes: 52 additions & 0 deletions tests/micro/arduino/test_arduino_error_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pathlib
import re
import sys

import pytest

import conftest
from tvm.micro.project_api.server import ServerError


# A new project and workspace dir is created for EVERY test
@pytest.fixture
def workspace_dir(request, platform):
return conftest.make_workspace_dir("arduino_error_detection", platform)


@pytest.fixture(scope="function")
def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir):
return conftest.make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir)


def test_blank_project_compiles(workspace_dir, project):
project.build()


# Add a bug (an extra curly brace) and make sure the project doesn't compile
def test_bugged_project_compile_fails(workspace_dir, project):
with open(workspace_dir / "project" / "project.ino", "a") as main_file:
main_file.write("}\n")
with pytest.raises(ServerError):
project.build()


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
4 changes: 2 additions & 2 deletions tests/micro/arduino/test_arduino_rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import conftest


# We'll make a new workspace for each test
@pytest.fixture(scope="function")
# # A new project and workspace dir is created for EVERY test
@pytest.fixture
def workspace_dir(platform):
return conftest.make_workspace_dir("arduino_rpc_server", platform)

Expand Down
86 changes: 3 additions & 83 deletions tests/micro/arduino/test_arduino_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@
import pathlib
import shutil
import sys
import tempfile

import pytest

import conftest
import tvm
from tvm import micro, relay
from tvm.micro.project_api.server import ServerError

"""
This unit test simulates a simple user workflow, where we:
Expand All @@ -40,7 +36,8 @@
"""


# Since these tests are sequential, we'll use the same project for all tests
# Since these tests are sequential, we'll use the same project/workspace
# directory for all tests in this file
@pytest.fixture(scope="module")
def workspace_dir(request, platform):
return conftest.make_workspace_dir("arduino_workflow", platform)
Expand All @@ -51,73 +48,10 @@ def project_dir(workspace_dir):
return workspace_dir / "project"


# Saves the Arduino project's state, runs the test, then resets it
@pytest.fixture(scope="function")
def does_not_affect_state(project_dir):
with tempfile.TemporaryDirectory() as temp_dir:
prev_project_state = pathlib.Path(temp_dir)
shutil.copytree(project_dir, prev_project_state / "project")
yield

# We can't delete project_dir or it'll mess up the Arduino CLI working directory
# Instead, delete everything in project_dir, and then copy over the files
for item in project_dir.iterdir():
if item.is_dir():
shutil.rmtree(item)
else:
item.unlink() # Delete file
# Once we upgrade to Python 3.7, this can be replaced with
# shutil.copytree(dirs_exist_ok=True)
for item in (prev_project_state / "project").iterdir():
if item.is_dir():
shutil.copytree(item, project_dir / item.name)
else:
shutil.copy2(item, project_dir)


def _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config):
return tvm.micro.generate_project(
str(conftest.TEMPLATE_PROJECT_DIR),
mod,
workspace_dir / "project",
{
"arduino_board": arduino_board,
"arduino_cli_cmd": arduino_cli_cmd,
"project_type": "example_project",
"verbose": bool(build_config.get("debug")),
},
)


# We MUST pass workspace_dir, not project_dir, or the workspace will be dereferenced too soon
@pytest.fixture(scope="module")
def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir):
this_dir = pathlib.Path(__file__).parent
model, arduino_board = conftest.PLATFORMS[platform]
build_config = {"debug": tvm_debug}

with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f:
tflite_model_buf = f.read()

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite

tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

mod, params = relay.frontend.from_tflite(tflite_model)
target = tvm.target.target.micro(
model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"]
)

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = relay.build(mod, target, params=params)

return _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config)
return conftest.make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir)


def _get_directory_elements(directory):
Expand Down Expand Up @@ -161,20 +95,6 @@ def test_import_rerouting(project_dir, project):
assert "include/tvm/runtime/crt/platform.h" in c_backend_api_c


@pytest.mark.usefixtures("does_not_affect_state")
def test_blank_project_compiles(project):
project.build()


# Add a bug (an extra curly brace) and make sure the project doesn't compile
@pytest.mark.usefixtures("does_not_affect_state")
def test_bugged_project_compile_fails(project_dir, project):
with open(project_dir / "project.ino", "a") as main_file:
main_file.write("}\n")
with pytest.raises(ServerError):
project.build()


# Build on top of the generated project by replacing the
# top-level .ino fileand adding data input files, much
# like a user would
Expand Down

0 comments on commit e873436

Please sign in to comment.