Skip to content

Commit

Permalink
[microTVM][AutoTVM] Fix autotvm bug and tests (apache#9003)
Browse files Browse the repository at this point in the history
* debuggging

* cleanup and fix tutorial, zephyr and crt test

* fix crt test

* address comments
  • Loading branch information
mehrdadh authored and ylc committed Sep 29, 2021
1 parent 2fbaac6 commit 93b1069
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 39 deletions.
15 changes: 8 additions & 7 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import logging
import os
import pathlib
import contextlib

from typing import Union
from .._ffi import libinfo
from .. import rpc as _rpc

Expand Down Expand Up @@ -67,21 +69,24 @@ class AutoTvmModuleLoader:
Parameters
----------
template_project_dir : str
template_project_dir : Union[pathlib.Path, str]
project template path
project_options : dict
project generation option
"""

def __init__(self, template_project_dir: str, project_options: dict = None):
def __init__(
self, template_project_dir: Union[pathlib.Path, str], project_options: dict = None
):
self._project_options = project_options

if isinstance(template_project_dir, pathlib.Path):
if isinstance(template_project_dir, (pathlib.Path, str)):
self._template_project_dir = str(template_project_dir)
elif not isinstance(template_project_dir, str):
raise TypeError(f"Incorrect type {type(template_project_dir)}.")

@contextlib.contextmanager
def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
build_result_bin = build_file.read()
Expand All @@ -100,10 +105,6 @@ def __call__(self, remote_kw, build_result):
)
system_lib = remote.get_function("runtime.SystemLib")()
yield remote, system_lib
try:
remote.get_function("tvm.micro.destroy_micro_session")()
except tvm.error.TVMError as exception:
_LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1)


def autotvm_build_func():
Expand Down
31 changes: 9 additions & 22 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __enter__(self):
int(timeouts.session_start_retry_timeout_sec * 1e6),
int(timeouts.session_start_timeout_sec * 1e6),
int(timeouts.session_established_timeout_sec * 1e6),
self._shutdown,
)
)
self.device = self._rpc.cpu(0)
Expand All @@ -143,6 +144,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
"""Tear down this session and associated RPC session resources."""
self.transport.__exit__(exc_type, exc_value, exc_traceback)

def _shutdown(self):
self.__exit__(None, None, None)


def lookup_remote_linked_param(mod, storage_id, template_tensor, device):
"""Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module.
Expand Down Expand Up @@ -239,9 +243,6 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
)


RPC_SESSION = None


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
Expand All @@ -264,7 +265,6 @@ def compile_and_create_micro_session(
project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""
global RPC_SESSION

temp_dir = utils.tempdir()
# Keep temp directory for generate project
Expand All @@ -277,7 +277,7 @@ def compile_and_create_micro_session(
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
temp_dir / "generated-project",
str(temp_dir / "generated-project"),
options=json.loads(project_options),
)
except Exception as exception:
Expand All @@ -288,20 +288,7 @@ def compile_and_create_micro_session(
generated_project.flash()
transport = generated_project.transport()

RPC_SESSION = Session(transport_context_manager=transport)
RPC_SESSION.__enter__()
return RPC_SESSION._rpc._sess


@register_func
def destroy_micro_session():
"""Destroy RPC session for microTVM autotune."""
global RPC_SESSION

if RPC_SESSION is not None:
exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None)
RPC_SESSION = None
if (exc_type, exc_value, traceback) != (None, None, None):
exc = exc_type(exc_value) # See PEP 3109
exc.__traceback__ = traceback
raise exc
rpc_session = Session(transport_context_manager=transport)
# RPC exit is called by shutdown function.
rpc_session.__enter__()
return rpc_session._rpc._sess
2 changes: 1 addition & 1 deletion src/runtime/micro/micro_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue*
throw std::runtime_error(ss.str());
}
std::unique_ptr<RPCChannel> channel(micro_channel);
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "");
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]);
auto sess = CreateClientSession(ep);
*rv = CreateRPCSessionModule(sess);
});
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,13 @@ void RPCEndpoint::Init() {
* the key to modify their behavior.
*/
std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(std::unique_ptr<RPCChannel> channel,
std::string name, std::string remote_key) {
std::string name, std::string remote_key,
TypedPackedFunc<void()> fshutdown) {
std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>();
endpt->channel_ = std::move(channel);
endpt->name_ = std::move(name);
endpt->remote_key_ = std::move(remote_key);
endpt->fshutdown_ = fshutdown;
endpt->Init();
return endpt;
}
Expand Down Expand Up @@ -734,6 +736,7 @@ void RPCEndpoint::ServerLoop() {
(*f)();
}
channel_.reset(nullptr);
if (fshutdown_ != nullptr) fshutdown_();
}

int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) {
Expand Down
6 changes: 5 additions & 1 deletion src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ class RPCEndpoint {
* \param channel The communication channel.
* \param name The local name of the session, used for debug
* \param remote_key The remote key of the session
* \param fshutdown The shutdown Packed function
* if remote_key equals "%toinit", we need to re-intialize
* it by event handler.
*/
static std::shared_ptr<RPCEndpoint> Create(std::unique_ptr<RPCChannel> channel, std::string name,
std::string remote_key);
std::string remote_key,
TypedPackedFunc<void()> fshutdown = nullptr);

private:
class EventHandler;
Expand All @@ -190,6 +192,8 @@ class RPCEndpoint {
std::string name_;
// The remote key
std::string remote_key_;
// The shutdown Packed Function
TypedPackedFunc<void()> fshutdown_;
};

/*!
Expand Down
4 changes: 3 additions & 1 deletion tests/micro/zephyr/test_zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -457,6 +457,8 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
si_prefix="M",
)

assert tuner.best_flops > 0

# Build without tuning
with pass_context:
lowered = tvm.relay.build(mod, target=target, params=params)
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_autotune():
inputs = {"data": input_data}

target = tvm.target.target.micro("host")
template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host")
template_project_dir = pathlib.Path(tvm.micro.get_standalone_crt_dir()) / "template" / "host"

pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True})
with pass_context:
Expand All @@ -271,7 +271,7 @@ def test_autotune():
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -292,6 +292,8 @@ def test_autotune():
si_prefix="M",
)

assert tuner.best_flops > 0

# Build without tuning
with pass_context:
lowered = tvm.relay.build(mod, target=TARGET, params=params)
Expand Down
8 changes: 4 additions & 4 deletions tutorials/micro/micro_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -146,7 +146,7 @@
# do_fork=False,
# build_func=tvm.micro.autotvm_build_func,
# )
# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

# measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -162,7 +162,7 @@
n_trial=num_trials,
measure_option=measure_option,
callbacks=[
tvm.autotvm.callback.log_to_file("microtvm_autotune.log"),
tvm.autotvm.callback.log_to_file("microtvm_autotune.log.txt"),
tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"),
],
si_prefix="M",
Expand Down Expand Up @@ -214,7 +214,7 @@
##########################
# Once autotuning completes, you can time execution of the entire program using the Debug Runtime:

with tvm.autotvm.apply_history_best("microtvm_autotune.log"):
with tvm.autotvm.apply_history_best("microtvm_autotune.log.txt"):
with pass_context:
lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params)

Expand Down

0 comments on commit 93b1069

Please sign in to comment.