diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index bcb2bddf2cab8..f06171d1b8e1b 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -20,6 +20,7 @@ import pytest import tvm.target.target +from tvm import micro, relay # The models that should pass this configuration. Maps a short, identifying platform string to # (model, zephyr_board). @@ -121,3 +122,42 @@ def make_workspace_dir(test_name, platform): t = tvm.contrib.utils.tempdir(board_workspace) # time.sleep(200) return t + + +def make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + this_dir = pathlib.Path(__file__).parent + model, arduino_board = PLATFORMS[platform] + build_config = {"debug": tvm_debug} + + with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f: + tflite_model_buf = f.read() + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + + mod, params = relay.frontend.from_tflite(tflite_model) + target = tvm.target.target.micro( + model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"] + ) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = relay.build(mod, target, params=params) + + return tvm.micro.generate_project( + str(TEMPLATE_PROJECT_DIR), + mod, + workspace_dir / "project", + { + "arduino_board": arduino_board, + "arduino_cli_cmd": arduino_cli_cmd, + "project_type": "example_project", + "verbose": bool(build_config.get("debug")), + }, + ) diff --git a/tests/micro/arduino/test_arduino_error_detection.py b/tests/micro/arduino/test_arduino_error_detection.py new file mode 100644 index 0000000000000..2c3873d571f93 --- /dev/null +++ b/tests/micro/arduino/test_arduino_error_detection.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pathlib +import re +import sys + +import pytest + +import conftest +from tvm.micro.project_api.server import ServerError + + +@pytest.fixture(scope="function") +def workspace_dir(request, platform): + return conftest.make_workspace_dir("arduino_error_detection", platform) + + +@pytest.fixture(scope="function") +def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir): + return conftest.make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir) + + +def test_blank_project_compiles(workspace_dir, project): + project.build() + + +# Add a bug (an extra curly brace) and make sure the project doesn't compile +def test_bugged_project_compile_fails(workspace_dir, project): + with open(workspace_dir / "project" / "project.ino", "a") as main_file: + main_file.write("}\n") + with pytest.raises(ServerError): + project.build() + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/micro/arduino/test_arduino_workflow.py b/tests/micro/arduino/test_arduino_workflow.py index 4f16c486142c2..dbd67e17d2f9a 100644 --- a/tests/micro/arduino/test_arduino_workflow.py +++ b/tests/micro/arduino/test_arduino_workflow.py @@ -19,14 +19,10 @@ import pathlib import shutil import sys -import tempfile import pytest import conftest -import tvm -from tvm import micro, relay -from tvm.micro.project_api.server import ServerError """ This unit test simulates a simple user workflow, where we: @@ -51,73 +47,10 @@ def project_dir(workspace_dir): return workspace_dir / "project" -# Saves the Arduino project's state, runs the test, then resets it -@pytest.fixture(scope="function") -def does_not_affect_state(project_dir): - with tempfile.TemporaryDirectory() as temp_dir: - prev_project_state = pathlib.Path(temp_dir) - shutil.copytree(project_dir, prev_project_state / "project") - yield - - # We can't delete project_dir or it'll mess up the Arduino CLI working directory - # Instead, delete everything in project_dir, and then copy over the files - for item in project_dir.iterdir(): - if item.is_dir(): - shutil.rmtree(item) - else: - item.unlink() # Delete file - # Once we upgrade to Python 3.7, this can be replaced with - # shutil.copytree(dirs_exist_ok=True) - for item in (prev_project_state / "project").iterdir(): - if item.is_dir(): - shutil.copytree(item, project_dir / item.name) - else: - shutil.copy2(item, project_dir) - - -def _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config): - return tvm.micro.generate_project( - str(conftest.TEMPLATE_PROJECT_DIR), - mod, - workspace_dir / "project", - { - "arduino_board": arduino_board, - "arduino_cli_cmd": arduino_cli_cmd, - "project_type": "example_project", - "verbose": bool(build_config.get("debug")), - }, - ) - - # We MUST pass workspace_dir, not project_dir, or the workspace will be dereferenced too soon @pytest.fixture(scope="module") def project(platform, arduino_cli_cmd, tvm_debug, workspace_dir): - this_dir = pathlib.Path(__file__).parent - model, arduino_board = conftest.PLATFORMS[platform] - build_config = {"debug": tvm_debug} - - with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f: - tflite_model_buf = f.read() - - # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 - try: - import tflite.Model - - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - except AttributeError: - import tflite - - tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) - - mod, params = relay.frontend.from_tflite(tflite_model) - target = tvm.target.target.micro( - model, options=["--link-params=1", "--unpacked-api=1", "--executor=aot"] - ) - - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = relay.build(mod, target, params=params) - - return _generate_project(arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config) + return conftest.make_kws_project(platform, arduino_cli_cmd, tvm_debug, workspace_dir) def _get_directory_elements(directory): @@ -161,20 +94,6 @@ def test_import_rerouting(project_dir, project): assert "include/tvm/runtime/crt/platform.h" in c_backend_api_c -@pytest.mark.usefixtures("does_not_affect_state") -def test_blank_project_compiles(project): - project.build() - - -# Add a bug (an extra curly brace) and make sure the project doesn't compile -@pytest.mark.usefixtures("does_not_affect_state") -def test_bugged_project_compile_fails(project_dir, project): - with open(project_dir / "project.ino", "a") as main_file: - main_file.write("}\n") - with pytest.raises(ServerError): - project.build() - - # Build on top of the generated project by replacing the # top-level .ino fileand adding data input files, much # like a user would