Skip to content

Commit

Permalink
fix autotvm test
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Sep 16, 2021
1 parent 820a57c commit 8a64854
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _is_qemu(cls, options):

@classmethod
def _has_fpu(cls, zephyr_board):
fpu_boards = ([name for name, board in BOARD_PROPERTIES.items() if board["fpu"]],)
fpu_boards = [name for name, board in BOARD_PROPERTIES.items() if board["fpu"]]
return zephyr_board in fpu_boards

def flash(self, options):
Expand Down
26 changes: 14 additions & 12 deletions python/tvm/micro/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Defines glue wrappers around the Project API which mate to TVM interfaces."""

import pathlib
import typing
from typing import Union

from .. import __version__
from ..contrib import utils
Expand Down Expand Up @@ -64,7 +64,7 @@ class GeneratedProject:
"""Defines a glue interface to interact with a generated project through the API server."""

@classmethod
def from_directory(cls, project_dir: typing.Union[pathlib.Path, str], options: dict):
def from_directory(cls, project_dir: Union[pathlib.Path, str], options: dict):
return cls(client.instantiate_from_dir(project_dir), options)

def __init__(self, api_client, options):
Expand Down Expand Up @@ -101,7 +101,17 @@ def __init__(self, api_client):
if not self._info["is_template"]:
raise NotATemplateProjectError()

def _check_project_options(self, options: dict):
"""Check if options are valid ProjectOptions"""
valid_project_options = [item["name"] for item in self.info()["project_options"]]

if not all(element in list(valid_project_options) for element in list(options)):
raise ValueError(
f"options:{list(options)} include none valid ProjectOptions. Here is a list of valid options:{list(valid_project_options)}."
)

def generate_project_from_mlf(self, model_library_format_path, project_dir, options):
self._check_project_options(options)
self._api_client.generate_project(
model_library_format_path=str(model_library_format_path),
standalone_crt_dir=get_standalone_crt_dir(),
Expand All @@ -124,9 +134,9 @@ def generate_project(self, graph_executor_factory, project_dir, options):


def generate_project(
template_project_dir: typing.Union[pathlib.Path, str],
template_project_dir: Union[pathlib.Path, str],
module: ExportableModule,
generated_project_dir: typing.Union[pathlib.Path, str],
generated_project_dir: Union[pathlib.Path, str],
options: dict = None,
):
"""Generate a project for an embedded platform that contains the given model.
Expand All @@ -153,12 +163,4 @@ def generate_project(
A class that wraps the generated project and which can be used to further interact with it.
"""
template = TemplateProject.from_directory(str(template_project_dir))

# check if options are valid
valid_project_options = [item["name"] for item in template.info()["project_options"]]
if not all(element in list(valid_project_options) for element in list(options)):
raise ValueError(
f"options:{list(options)} include none valid ProjectOptions. Here is a list of valid options:{list(valid_project_options)}."
)

return template.generate_project(module, str(generated_project_dir), options)
8 changes: 4 additions & 4 deletions tests/micro/zephyr/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ def zephyr_boards() -> dict:
ZEPHYR_BOARDS = zephyr_boards()


def qemu_boards() -> list:
"""Returns a list of QEMU Zephyr boards."""
def qemu_boards(board: str):
"""Returns True if board is QEMU."""
with open(BOARD_JSON_PATH) as f:
board_properties = json.load(f)

qemu_boards = [name for name, board in board_properties.items() if board["is_qemu"]]
return qemu_boards
return board in qemu_boards


def has_fpu(board):
def has_fpu(board: str):
"""Returns True if board has FPU."""
with open(BOARD_JSON_PATH) as f:
board_properties = json.load(f)
Expand Down
12 changes: 6 additions & 6 deletions tests/micro/zephyr/test_zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ def _make_sess_from_op(

return _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config)


def _make_session(temp_dir, zephyr_board, west_cmd, mod, build_config):
stack_size = None
if zephyr_board in conftest.qemu_boards():
if conftest.qemu_boards(zephyr_board):
stack_size = 1536

project = tvm.micro.generate_project(
Expand Down Expand Up @@ -416,9 +417,9 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
)
template_project_dir = repo_root / "apps" / "microtvm" / "zephyr" / "template_project"

stack_size = None
if board in conftest.qemu_boards():
if conftest.qemu_boards(board):
stack_size = 1536

module_loader = tvm.micro.AutoTvmModuleLoader(
Expand All @@ -432,7 +433,6 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
},
)
builder = tvm.autotvm.LocalBuilder(
timeout=100,
n_parallel=1,
build_kwargs={"build_option": {"tir.disable_vectorize": True}},
do_fork=True,
Expand All @@ -446,7 +446,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
if log_path.exists():
log_path.unlink()

n_trial = 1
n_trial = 10
for task in tasks:
tuner = tvm.autotvm.tuner.GATuner(task)
tuner.tune(
Expand All @@ -466,7 +466,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
lowered = tvm.relay.build(mod, target=target, params=params)

temp_dir = utils.tempdir()
with _make_session(temp_dir, board, west_cmd, mod, build_config) as session:
with _make_session(temp_dir, board, west_cmd, lowered, build_config) as session:
graph_mod = tvm.micro.create_local_graph_executor(
lowered.get_graph_json(), session.get_system_lib(), session.device
)
Expand Down

0 comments on commit 8a64854

Please sign in to comment.