diff --git a/apps/microtvm/arduino/example_project/project.ino b/apps/microtvm/arduino/example_project/project.ino index 74b596245624..5f5683161e0a 100644 --- a/apps/microtvm/arduino/example_project/project.ino +++ b/apps/microtvm/arduino/example_project/project.ino @@ -21,9 +21,10 @@ void setup() { TVMInitialize(); - //TVMExecute(input_data, output_data); + // If desired, initialize the RNG with random noise + // randomSeed(analogRead(0)); } void loop() { - // put your main code here, to run repeatedly: + //TVMExecute(input_data, output_data); } diff --git a/apps/microtvm/arduino/example_project/src/model.c b/apps/microtvm/arduino/example_project/src/model.c index 7acb06fa89af..fc8b5836314b 100644 --- a/apps/microtvm/arduino/example_project/src/model.c +++ b/apps/microtvm/arduino/example_project/src/model.c @@ -17,9 +17,6 @@ * under the License. */ -#ifndef TVM_IMPLEMENTATION_ARDUINO -#define TVM_IMPLEMENTATION_ARDUINO - #include "model.h" #include "Arduino.h" @@ -33,6 +30,7 @@ tvm_workspace_t app_workspace; // Blink code for debugging purposes void TVMPlatformAbort(tvm_crt_error_t error) { + TVMLogf("TVMPlatformAbort: %08x\n", error); for (;;) { #ifdef LED_BUILTIN digitalWrite(LED_BUILTIN, HIGH); @@ -57,10 +55,7 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { return StackMemoryManager_Free(&app_workspace, ptr); } -unsigned long g_utvm_start_time; - -#define MILLIS_TIL_EXPIRY 200 - +unsigned long g_utvm_start_time_micros; int g_utvm_timer_running = 0; tvm_crt_error_t TVMPlatformTimerStart() { @@ -68,7 +63,7 @@ tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorPlatformTimerBadState; } g_utvm_timer_running = 1; - g_utvm_start_time = micros(); + g_utvm_start_time_micros = micros(); return kTvmErrorNoError; } @@ -77,7 +72,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { return kTvmErrorPlatformTimerBadState; } g_utvm_timer_running = 0; - unsigned long g_utvm_stop_time = micros() - g_utvm_start_time; + unsigned long g_utvm_stop_time = micros() - g_utvm_start_time_micros; *elapsed_time_seconds = ((double)g_utvm_stop_time) / 1e6; return kTvmErrorNoError; } @@ -97,5 +92,3 @@ void TVMExecute(void* input_data, void* output_data) { TVMPlatformAbort(kTvmErrorPlatformCheckFailure); } } - -#endif diff --git a/apps/microtvm/arduino/example_project/src/model.h b/apps/microtvm/arduino/example_project/src/model.h index 489adff0da53..7381c97e9b3f 100644 --- a/apps/microtvm/arduino/example_project/src/model.h +++ b/apps/microtvm/arduino/example_project/src/model.h @@ -17,9 +17,6 @@ * under the License. */ -#ifndef IMPLEMENTATION_H_ -#define IMPLEMENTATION_H_ - #define WORKSPACE_SIZE $workspace_size_bytes #ifdef __cplusplus @@ -28,11 +25,16 @@ extern "C" { void TVMInitialize(); -// TODO template these void* values once MLF format has input and output data +/* TODO template this function signature with the input and output + * data types and sizes. For example: + * + * void TVMExecute(uint8_t input_data[9216], uint8_t output_data[3]); + * + * Note this can only be done once MLF has JSON metadata describing + * inputs and outputs. + */ void TVMExecute(void* input_data, void* output_data); #ifdef __cplusplus } // extern "C" #endif - -#endif // IMPLEMENTATION_H_ diff --git a/apps/microtvm/arduino/host_driven/project.ino b/apps/microtvm/arduino/host_driven/project.ino index 34537d4e205f..c1b7f3870400 100644 --- a/apps/microtvm/arduino/host_driven/project.ino +++ b/apps/microtvm/arduino/host_driven/project.ino @@ -19,7 +19,6 @@ #include "src/standalone_crt/include/tvm/runtime/crt/microtvm_rpc_server.h" #include "src/standalone_crt/include/tvm/runtime/crt/logging.h" -#include "src/model.h" microtvm_rpc_server_t server; // Called by TVM to write serial data to the UART. @@ -32,22 +31,23 @@ void setup() { server = MicroTVMRpcServerInit(write_serial, NULL); TVMLogf("microTVM Arduino runtime - running"); Serial.begin(115200); + + // If desired, initialize the RNG with random noise + // randomSeed(analogRead(0)); } void loop() { - int to_read = Serial.available(); + // Read at most 128 bytes at a time to prevent stack blowup + int to_read = min(Serial.available(), 128); + uint8_t data[to_read]; - size_t bytes_read = Serial.readBytes(data, to_read); + size_t bytes_remaining = Serial.readBytes(data, to_read); uint8_t* arr_ptr = data; - uint8_t** data_ptr = &arr_ptr; - if (bytes_read > 0) { - size_t bytes_remaining = bytes_read; - while (bytes_remaining > 0) { - // Pass the received bytes to the RPC server. - tvm_crt_error_t err = MicroTVMRpcServerLoop(server, data_ptr, &bytes_remaining); - if (err != kTvmErrorNoError && err != kTvmErrorFramingShortPacket) { - TVMPlatformAbort(err); - } + while (bytes_remaining > 0) { + // Pass the received bytes to the RPC server. + tvm_crt_error_t err = MicroTVMRpcServerLoop(server, &arr_ptr, &bytes_remaining); + if (err != kTvmErrorNoError && err != kTvmErrorFramingShortPacket) { + TVMPlatformAbort(err); } } } diff --git a/apps/microtvm/arduino/host_driven/src/model.h b/apps/microtvm/arduino/host_driven/src/model.h deleted file mode 100644 index edc83e5123f8..000000000000 --- a/apps/microtvm/arduino/host_driven/src/model.h +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef IMPLEMENTATION_H_ -#define IMPLEMENTATION_H_ - -#define WORKSPACE_SIZE 65535 - -#endif // IMPLEMENTATION_H_ diff --git a/apps/microtvm/arduino/host_driven/src/model.c b/apps/microtvm/arduino/host_driven/src/model_support.c similarity index 80% rename from apps/microtvm/arduino/host_driven/src/model.c rename to apps/microtvm/arduino/host_driven/src/model_support.c index 5b87deb526d5..6fe36099227f 100644 --- a/apps/microtvm/arduino/host_driven/src/model.c +++ b/apps/microtvm/arduino/host_driven/src/model_support.c @@ -17,29 +17,13 @@ * under the License. */ -#ifndef TVM_IMPLEMENTATION_ARDUINO -#define TVM_IMPLEMENTATION_ARDUINO - -#include "model.h" - -#include "Arduino.h" -#include "standalone_crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h" #include "stdarg.h" +#include "standalone_crt/include/tvm/runtime/crt/internal/aot_executor/aot_executor.h" // Blink code for debugging purposes void TVMPlatformAbort(tvm_crt_error_t error) { - for (;;) { -#ifdef LED_BUILTIN - digitalWrite(LED_BUILTIN, HIGH); - delay(250); - digitalWrite(LED_BUILTIN, LOW); - delay(250); - digitalWrite(LED_BUILTIN, HIGH); - delay(250); - digitalWrite(LED_BUILTIN, LOW); - delay(750); -#endif - } + TVMLogf("TVMPlatformAbort: %08x\n", error); + for (;;); } size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt, @@ -60,10 +44,7 @@ tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) { return kTvmErrorNoError; } -unsigned long g_utvm_start_time; - -#define MILLIS_TIL_EXPIRY 200 - +unsigned long g_utvm_start_time_micros; int g_utvm_timer_running = 0; tvm_crt_error_t TVMPlatformTimerStart() { @@ -71,7 +52,7 @@ tvm_crt_error_t TVMPlatformTimerStart() { return kTvmErrorPlatformTimerBadState; } g_utvm_timer_running = 1; - g_utvm_start_time = micros(); + g_utvm_start_time_micros = micros(); return kTvmErrorNoError; } @@ -80,7 +61,7 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) { return kTvmErrorPlatformTimerBadState; } g_utvm_timer_running = 0; - unsigned long g_utvm_stop_time = micros() - g_utvm_start_time; + unsigned long g_utvm_stop_time = micros() - g_utvm_start_time_micros; *elapsed_time_seconds = ((double)g_utvm_stop_time) / 1e6; return kTvmErrorNoError; } @@ -91,5 +72,3 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { } return kTvmErrorNoError; } - -#endif diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 7297bab6405d..a348d37288cc 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -28,13 +28,12 @@ import subprocess import sys import tarfile -from string import Template import tempfile import time +from string import Template import serial import serial.tools.list_ports - from tvm.micro.project_api import server MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar" @@ -45,20 +44,20 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() -class InvalidPortException(Exception): - """Raised when the given port could not be opened""" - - -class SketchUploadException(Exception): - """Raised when a sketch cannot be uploaded for an unknown reason.""" - - class BoardAutodetectFailed(Exception): """Raised when no attached hardware is found matching the requested board""" +# Data structure to hold the information microtvm_api_server.py needs +# to communicate with each of these boards. Currently just holds the +# components of each board's FQBN, but might be extended in the future +# to include the SRAM, PSRAM, flash, etc. on each board. BOARD_PROPERTIES = { - "due": {"package": "arduino", "architecture": "sam", "board": "arduino_due_x"}, + "due": { + "package": "arduino", + "architecture": "sam", + "board": "arduino_due_x", + }, # Due to the way the Feather S2 bootloader works, compilation # behaves fine but uploads cannot be done automatically "feathers2": { @@ -108,7 +107,7 @@ class BoardAutodetectFailed(Exception): server.ProjectOption("arduino_cli_cmd", help="Path to the arduino-cli tool."), server.ProjectOption("port", help="Port to use for connecting to hardware"), server.ProjectOption( - "example_project", + "project_type", help="Type of project to generate.", choices=tuple(PROJECT_TYPES), ), @@ -201,6 +200,10 @@ def _template_model_header(self, source_dir, metadata): with open(source_dir / "model.h", "r") as f: model_h_template = Template(f.read()) + # The structure of the "memory" key depends on the style - + # only style="full-model" works with AOT, so we'll check that + assert metadata["style"] == "full-model" + template_values = { "workspace_size_bytes": metadata["memory"]["functions"]["main"][0][ "workspace_size_bytes" @@ -222,47 +225,20 @@ def _change_cpp_file_extensions(self, source_dir): for filename in source_dir.rglob(f"*.inc"): filename.rename(filename.with_suffix(".h")) - def _process_autogenerated_inc_files(self, source_dir): - for filename in source_dir.rglob(f"*.inc"): - # Individual file fixes - if filename.stem == "gentab_ccitt": - with open(filename, "r+") as f: - content = f.read() - f.seek(0, 0) - f.write('#include "inttypes.h"\n' + content) - - filename.rename(filename.with_suffix(".c")) - - POSSIBLE_BASE_PATHS = ["src/standalone_crt/include/", "src/standalone_crt/crt_config/"] - - def _find_modified_include_path(self, project_dir, file_path, import_path): - # If the import is for a .inc file we renamed to .c earlier, fix it - if import_path.endswith(self.CPP_FILE_EXTENSION_SYNONYMS): - import_path = re.sub(r"\.[a-z]+$", ".cpp", import_path) + """Arduino only supports includes relative to the top-level project, so this + finds each time we #include a file and changes the path to be relative to the + top-level project.ino file. For example, the line: - if import_path.endswith(".inc"): - import_path = re.sub(r"\.[a-z]+$", ".h", import_path) + #include - # If the import already works, don't modify it - if (file_path.parents[0] / import_path).exists(): - return import_path + Might be changed to (depending on the source file's location): - relative_path = file_path.relative_to(project_dir) - up_dirs_path = "../" * str(relative_path).count("/") - - for base_path in self.POSSIBLE_BASE_PATHS: - full_potential_path = project_dir / base_path / import_path - if full_potential_path.exists(): - new_include = up_dirs_path + base_path + import_path - return new_include + #include "../../../../include/tvm/runtime/crt/platform.h" - # If we can't find the file, just leave it untouched - # It's probably a standard C/C++ header - return import_path + We also need to leave standard library includes as-is. + """ - # Arduino only supports imports relative to the top-level project, - # so we need to adjust each import to meet this convention - def _convert_imports(self, project_dir, source_dir): + def _convert_includes(self, project_dir, source_dir): for ext in ("c", "h", "cpp"): for filename in source_dir.rglob(f"*.{ext}"): with filename.open() as file: @@ -282,6 +258,55 @@ def _convert_imports(self, project_dir, source_dir): with filename.open("w") as file: file.writelines(lines) + # Most of the files we used to be able to point to directly are under "src/standalone_crt/include/". + # Howver, crt_config.h lives under "src/standalone_crt/crt_config/", and more exceptions might + # be added in the future. + POSSIBLE_BASE_PATHS = ["src/standalone_crt/include/", "src/standalone_crt/crt_config/"] + + """Takes a single #include path, and returns the new location + it should point to (as described above). For example, one of the + includes for "src/standalone_crt/src/runtime/crt/common/ndarray.c" is: + + #include + + For that line, _convert_includes might call _find_modified_include_path + with the arguments: + + project_dir = "/path/to/project/dir" + file_path = "/path/to/project/dir/src/standalone_crt/src/runtime/crt/common/ndarray.c" + include_path = "tvm/runtime/crt/platform.h" + + Given these arguments, _find_modified_include_path should return: + + "../../../../../../src/standalone_crt/include/tvm/runtime/crt/platform.h" + + See unit test in ./tests/test_arduino_microtvm_api_server.py + """ + + def _find_modified_include_path(self, project_dir, file_path, include_path): + if include_path.endswith(".inc"): + include_path = re.sub(r"\.[a-z]+$", ".h", include_path) + + # Change includes referencing .cc and .cxx files to point to the renamed .cpp file + if include_path.endswith(self.CPP_FILE_EXTENSION_SYNONYMS): + include_path = re.sub(r"\.[a-z]+$", ".cpp", include_path) + + # If the include already works, don't modify it + if (file_path.parents[0] / include_path).exists(): + return include_path + + relative_path = file_path.relative_to(project_dir) + up_dirs_path = "../" * str(relative_path).count("/") + + for base_path in self.POSSIBLE_BASE_PATHS: + full_potential_path = project_dir / base_path / include_path + if full_potential_path.exists(): + return up_dirs_path + base_path + include_path + + # If we can't find the file, just leave it untouched + # It's probably a standard C/C++ header + return include_path + def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Reference key directories with pathlib project_dir = pathlib.Path(project_dir) @@ -291,9 +316,8 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Copies files from the template folder to project_dir. model.h is copied here, # but will also need to be templated later. - if IS_TEMPLATE: - shutil.copy2(API_SERVER_DIR / "microtvm_api_server.py", project_dir) - self._copy_project_files(API_SERVER_DIR, project_dir, options["project_type"]) + shutil.copy2(API_SERVER_DIR / "microtvm_api_server.py", project_dir) + self._copy_project_files(API_SERVER_DIR, project_dir, options["project_type"]) # Copy standalone_crt into src folder self._copy_standalone_crt(source_dir, standalone_crt_dir) @@ -309,8 +333,8 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec self._change_cpp_file_extensions(source_dir) - # Recursively change imports - self._convert_imports(project_dir, source_dir) + # Recursively change includes + self._convert_includes(project_dir, source_dir) def _get_fqbn(self, options): o = BOARD_PROPERTIES[options["arduino_board"]] @@ -318,7 +342,6 @@ def _get_fqbn(self, options): def build(self, options): BUILD_DIR.mkdir() - print(BUILD_DIR) compile_cmd = [ options["arduino_cli_cmd"], @@ -334,28 +357,42 @@ def build(self, options): compile_cmd.append("--verbose") # Specify project to compile - output = subprocess.check_call(compile_cmd) - assert output == 0 + subprocess.run(compile_cmd) + + """We run the command `arduino-cli board list`, which produces + outputs of the form: - # We run the command `arduino-cli board list`, which produces - # outputs of the form: - """ Port Type Board Name FQBN Core /dev/ttyS4 Serial Port Unknown /dev/ttyUSB0 Serial Port (USB) Spresense SPRESENSE:spresense:spresense SPRESENSE:spresense """ + BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core") + + def _parse_boards_tabular_str(self, tabular_str): + str_rows = tabular_str.split("\n")[:-2] + header = str_rows[0] + indices = [header.index(h) for h in self.BOARD_LIST_HEADERS] + [len(header)] + + for str_row in str_rows[1:]: + parsed_row = [] + for cell_index in range(len(self.BOARD_LIST_HEADERS)): + start = indices[cell_index] + end = indices[cell_index + 1] + str_cell = str_row[start:end] + + # Remove trailing whitespace used for padding + parsed_row.append(str_cell.rstrip()) + yield parsed_row + def _auto_detect_port(self, options): list_cmd = [options["arduino_cli_cmd"], "board", "list"] - list_cmd_output = subprocess.check_output(list_cmd).decode("utf-8") - # Remove header and new lines at bottom - port_options = list_cmd_output.split("\n")[1:-2] + list_cmd_output = subprocess.run(list_cmd, capture_output=True).stdout.decode("utf-8") - # Select the first compatible board - fqbn = self._get_fqbn(options) - for port_option in port_options: - if fqbn in port_option: - return port_option.split(" ")[0] + desired_fqbn = self._get_fqbn(options) + for line in self._parse_boards_tabular_str(list_cmd_output): + if line[3] == desired_fqbn: + return line[0] # If no compatible boards, raise an error raise BoardAutodetectFailed() @@ -387,12 +424,7 @@ def flash(self, options): if options.get("verbose"): upload_cmd.append("--verbose") - output = subprocess.check_call(upload_cmd) - - if output == 2: - raise InvalidPortException() - elif output > 0: - raise SketchUploadException() + subprocess.run(upload_cmd) def open_transport(self, options): # Zephyr example doesn't throw an error in this case @@ -401,14 +433,13 @@ def open_transport(self, options): port = self._get_arduino_port(options) - # Wait for port to become available + # It takes a moment for the Arduino code to finish initializing + # and start communicating over serial for attempts in range(10): if any(serial.tools.list_ports.grep(port)): break time.sleep(0.5) - # TODO figure out why RPC serial communication times out 90% (not 100%) - # of the time on the Nano 33 BLE self._serial = serial.Serial(port, baudrate=115200, timeout=5) return server.TransportTimeouts( diff --git a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py index 3dd3688fe5b4..e576cc2c2d88 100644 --- a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py @@ -15,9 +15,9 @@ # specific language governing permissions and limitations # under the License. -from unittest import mock -from pathlib import Path import sys +from pathlib import Path +from unittest import mock import pytest @@ -64,6 +64,7 @@ def test_find_modified_include_path(self, mock_pathlib_path): BOARD_CONNECTED_OUTPUT = bytes( "Port Type Board Name FQBN Core \n" + "/dev/ttyACM1 Serial Port (USB) Wrong Arduino arduino:mbed_nano:nano33 arduino:mbed_nano\n" "/dev/ttyACM0 Serial Port (USB) Arduino Nano 33 BLE arduino:mbed_nano:nano33ble arduino:mbed_nano\n" "/dev/ttyS4 Serial Port Unknown \n" "\n", diff --git a/tests/micro/arduino/README.md b/tests/micro/arduino/README.md index f2c0535ff25a..78e63cabb7e2 100644 --- a/tests/micro/arduino/README.md +++ b/tests/micro/arduino/README.md @@ -22,19 +22,14 @@ all of the appropriate TVM dependencies installed. You can run the test with: ``` $ cd tvm/tests/micro/arduino -$ pytest test_arduino_workflow.py --platform spresense -$ pytest test_arduino_workflow.py --platform nano33ble +$ pytest --microtvm-platforms spresense ``` -By default, only project generation and compilation tests are run. If you -have compatible Arduino hardware connected, you can pass the flag -`--run-hardware-tests` to test board auto-detection and code execution: +Most of these tests require a supported Arduino board to be connected. +If you don't want to run these tests, you can pass the flag +`--test-build-only` to only test project generation and compilation. +To see the list of supported values for `----microtvm-platforms`, run: ``` -pytest test_arduino_workflow.py --platform spresense --run-hardware-tests -``` - -To see the list of supported values for `--platform`, run: -``` -$ pytest test_arduino_workflow.py --help +$ pytest --help ``` diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index f4e27d1458e3..19bba9b39536 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -19,7 +19,6 @@ import pathlib import pytest - import tvm.target.target # The models that should pass this configuration. Maps a short, identifying platform string to @@ -49,8 +48,8 @@ def pytest_addoption(parser): parser.addoption( "--microtvm-platforms", - default=["due"], - nargs="*", + nargs="+", + required=True, choices=PLATFORMS.keys(), help="Target platforms for microTVM tests.", ) @@ -60,18 +59,32 @@ def pytest_addoption(parser): help="Path to `arduino-cli` command for flashing device.", ) parser.addoption( - "--run-hardware-tests", + "--test-build-only", action="store_true", - help="Run tests that require physical hardware.", + help="Only run tests that don't require physical hardware.", ) parser.addoption( "--tvm-debug", action="store_true", default=False, - help="If set true, enable a debug session while the test is running. Before running the test, in a separate shell, you should run: ", + help="If given, enable a debug session while the test is running. Before running the test, in a separate shell, you should run: ", ) +def pytest_configure(config): + config.addinivalue_line( + "markers", "requires_hardware: mark test to run only when an Arduino board is connected" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--test-build-only"): + skip_hardware_tests = pytest.mark.skip(reason="--test-build-only was passed") + for item in items: + if "requires_hardware" in item.keywords: + item.add_marker(skip_hardware_tests) + + # We might do project generation differently for different boards in the future # (to take advantage of multiple cores / external memory / etc.), so all tests # are parameterized by board @@ -90,11 +103,6 @@ def tvm_debug(request): return request.config.getoption("--tvm-debug") -@pytest.fixture(scope="session") -def run_hardware_tests(request): - return request.config.getoption("--run-hardware-tests") - - def make_workspace_dir(test_name, platform): _, arduino_board = PLATFORMS[platform] filepath = pathlib.Path(__file__) diff --git a/tests/micro/arduino/test_arduino_rpc_server.py b/tests/micro/arduino/test_arduino_rpc_server.py index eee4a5623cff..fd656dbe38c3 100644 --- a/tests/micro/arduino/test_arduino_rpc_server.py +++ b/tests/micro/arduino/test_arduino_rpc_server.py @@ -15,38 +15,31 @@ # specific language governing permissions and limitations # under the License. +""" +This unit test simulates an autotuning workflow, where we: +1. Instantiate the Arduino RPC server project +2. Build and flash that project onto our target board + +""" + import datetime -import os import pathlib -import shutil import sys -import time import numpy as np import onnx -from PIL import Image import pytest -import tflite - import tvm +from PIL import Image from tvm import micro, relay from tvm.relay.testing import byoc import conftest -""" -This unit test simulates an autotuning workflow, where we: -1. Instantiate the Arduino RPC server project -2. Build and flash that project onto our target board - -""" # We'll make a new workspace for each test @pytest.fixture(scope="function") -def workspace_dir(platform, run_hardware_tests): - if not run_hardware_tests: - pytest.skip() - +def workspace_dir(platform): return conftest.make_workspace_dir("arduino_rpc_server", platform) @@ -89,6 +82,7 @@ def _make_add_sess(model, arduino_board, arduino_cli_cmd, workspace_dir, build_c # The same test code can be executed on both the QEMU simulation and on real hardware. @tvm.testing.requires_micro +@pytest.mark.requires_hardware def test_compile_runtime(platform, arduino_cli_cmd, tvm_debug, workspace_dir): """Test compiling the on-device runtime.""" @@ -110,10 +104,10 @@ def test_basic_add(sess): with _make_add_sess(model, arduino_board, arduino_cli_cmd, workspace_dir, build_config) as sess: test_basic_add(sess) - print(workspace_dir) @tvm.testing.requires_micro +@pytest.mark.requires_hardware def test_platform_timer(platform, arduino_cli_cmd, tvm_debug, workspace_dir): """Test compiling the on-device runtime.""" @@ -143,6 +137,7 @@ def test_basic_add(sess): @tvm.testing.requires_micro +@pytest.mark.requires_hardware def test_relay(platform, arduino_cli_cmd, tvm_debug, workspace_dir): """Testing a simple relay graph""" model, arduino_board = conftest.PLATFORMS[platform] @@ -176,13 +171,14 @@ def test_relay(platform, arduino_cli_cmd, tvm_debug, workspace_dir): @tvm.testing.requires_micro +@pytest.mark.requires_hardware def test_onnx(platform, arduino_cli_cmd, tvm_debug, workspace_dir): """Testing a simple ONNX model.""" model, arduino_board = conftest.PLATFORMS[platform] build_config = {"debug": tvm_debug} # Load test images. - this_dir = pathlib.Path(os.path.dirname(__file__)) + this_dir = pathlib.Path(__file__).parent testdata_dir = this_dir.parent / "testdata" digit_2 = Image.open(testdata_dir / "digit-2.jpg").resize((28, 28)) digit_2 = np.asarray(digit_2).astype("float32") @@ -263,6 +259,7 @@ def check_result( @tvm.testing.requires_micro +@pytest.mark.requires_hardware def test_byoc_microtvm(platform, arduino_cli_cmd, tvm_debug, workspace_dir): """This is a simple test case to check BYOC capabilities of microTVM""" model, arduino_board = conftest.PLATFORMS[platform] @@ -346,6 +343,7 @@ def _make_add_sess_with_shape( ], ) @tvm.testing.requires_micro +@pytest.mark.requires_hardware def test_rpc_large_array(platform, arduino_cli_cmd, tvm_debug, workspace_dir, shape): """Test large RPC array transfer.""" model, arduino_board = conftest.PLATFORMS[platform] diff --git a/tests/micro/arduino/test_arduino_workflow.py b/tests/micro/arduino/test_arduino_workflow.py index fddb47c8267c..18457763b9d0 100644 --- a/tests/micro/arduino/test_arduino_workflow.py +++ b/tests/micro/arduino/test_arduino_workflow.py @@ -19,7 +19,6 @@ import pathlib import shutil import sys -import time import pytest import tflite @@ -165,10 +164,7 @@ def test_compile_yes_no_project(project_dir, project, compiled_project): @pytest.fixture(scope="module") -def uploaded_project(compiled_project, run_hardware_tests): - if not run_hardware_tests: - pytest.skip() - +def uploaded_project(compiled_project): compiled_project.flash() return compiled_project @@ -186,9 +182,6 @@ def uploaded_project(compiled_project, run_hardware_tests): @pytest.fixture(scope="module") def serial_output(uploaded_project): - # Give time for the board to open a serial connection - time.sleep(1) - transport = uploaded_project.transport() transport.open() out = transport.read(2048, -1) @@ -212,6 +205,7 @@ def serial_output(uploaded_project): MAX_PREDICTION_DIFFERENCE = 2 +@pytest.mark.requires_hardware def test_project_inference_correctness(serial_output): predictions = {line[0]: line[2:] for line in serial_output} @@ -228,6 +222,7 @@ def test_project_inference_correctness(serial_output): MAX_INFERENCE_TIME_RANGE_US = 1000 +@pytest.mark.requires_hardware def test_project_inference_runtime(serial_output): runtimes_us = [line[1] for line in serial_output] diff --git a/tests/micro/zephyr/conftest.py b/tests/micro/zephyr/conftest.py index 51ed86568871..2b30401a90e9 100644 --- a/tests/micro/zephyr/conftest.py +++ b/tests/micro/zephyr/conftest.py @@ -89,8 +89,6 @@ def temp_dir(platform): _, zephyr_board = PLATFORMS[platform] parent_dir = pathlib.Path(os.path.dirname(__file__)) filename = os.path.splitext(os.path.basename(__file__))[0] - print(filename) - print("-----------------") board_workspace = ( parent_dir / f"workspace_{filename}_{zephyr_board}" diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index c6e6f65ce620..0793a96a457c 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -25,5 +25,5 @@ source tests/scripts/setup-pytest-env.sh make cython3 run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=host run_pytest ctypes python-microtvm-zephyr tests/micro/zephyr --microtvm-platforms=mps2_an521 -run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --microtvm-platforms=due -run_pytest ctypes python-microtvm-arduino-nano33ble tests/micro/arduino --microtvm-platforms=nano33ble +run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-only --microtvm-platforms=due +run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build-only --microtvm-platforms=nano33ble