From 93b10696c6fe564e55ce8e126645255ff8a07c2d Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 14 Sep 2021 15:41:01 -0700 Subject: [PATCH] [microTVM][AutoTVM] Fix autotvm bug and tests (#9003) * debuggging * cleanup and fix tutorial, zephyr and crt test * fix crt test * address comments --- python/tvm/micro/build.py | 15 ++++++++------- python/tvm/micro/session.py | 31 +++++++++--------------------- src/runtime/micro/micro_session.cc | 2 +- src/runtime/rpc/rpc_endpoint.cc | 5 ++++- src/runtime/rpc/rpc_endpoint.h | 6 +++++- tests/micro/zephyr/test_zephyr.py | 4 +++- tests/python/unittest/test_crt.py | 6 ++++-- tutorials/micro/micro_autotune.py | 8 ++++---- 8 files changed, 38 insertions(+), 39 deletions(-) diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 7da9daf958c6..9e278081933c 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -21,7 +21,9 @@ import logging import os import pathlib +import contextlib +from typing import Union from .._ffi import libinfo from .. import rpc as _rpc @@ -67,21 +69,24 @@ class AutoTvmModuleLoader: Parameters ---------- - template_project_dir : str + template_project_dir : Union[pathlib.Path, str] project template path project_options : dict project generation option """ - def __init__(self, template_project_dir: str, project_options: dict = None): + def __init__( + self, template_project_dir: Union[pathlib.Path, str], project_options: dict = None + ): self._project_options = project_options - if isinstance(template_project_dir, pathlib.Path): + if isinstance(template_project_dir, (pathlib.Path, str)): self._template_project_dir = str(template_project_dir) elif not isinstance(template_project_dir, str): raise TypeError(f"Incorrect type {type(template_project_dir)}.") + @contextlib.contextmanager def __call__(self, remote_kw, build_result): with open(build_result.filename, "rb") as build_file: build_result_bin = build_file.read() @@ -100,10 +105,6 @@ def __call__(self, remote_kw, build_result): ) system_lib = remote.get_function("runtime.SystemLib")() yield remote, system_lib - try: - remote.get_function("tvm.micro.destroy_micro_session")() - except tvm.error.TVMError as exception: - _LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1) def autotvm_build_func(): diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index abe7aff766e2..ced20b7ebfbf 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -130,6 +130,7 @@ def __enter__(self): int(timeouts.session_start_retry_timeout_sec * 1e6), int(timeouts.session_start_timeout_sec * 1e6), int(timeouts.session_established_timeout_sec * 1e6), + self._shutdown, ) ) self.device = self._rpc.cpu(0) @@ -143,6 +144,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback): """Tear down this session and associated RPC session resources.""" self.transport.__exit__(exc_type, exc_value, exc_traceback) + def _shutdown(self): + self.__exit__(None, None, None) + def lookup_remote_linked_param(mod, storage_id, template_tensor, device): """Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module. @@ -239,9 +243,6 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None): ) -RPC_SESSION = None - - @register_func("tvm.micro.compile_and_create_micro_session") def compile_and_create_micro_session( mod_src_bytes: bytes, @@ -264,7 +265,6 @@ def compile_and_create_micro_session( project_options: dict Options for the microTVM API Server contained in template_project_dir. """ - global RPC_SESSION temp_dir = utils.tempdir() # Keep temp directory for generate project @@ -277,7 +277,7 @@ def compile_and_create_micro_session( template_project = project.TemplateProject.from_directory(template_project_dir) generated_project = template_project.generate_project_from_mlf( model_library_format_path, - temp_dir / "generated-project", + str(temp_dir / "generated-project"), options=json.loads(project_options), ) except Exception as exception: @@ -288,20 +288,7 @@ def compile_and_create_micro_session( generated_project.flash() transport = generated_project.transport() - RPC_SESSION = Session(transport_context_manager=transport) - RPC_SESSION.__enter__() - return RPC_SESSION._rpc._sess - - -@register_func -def destroy_micro_session(): - """Destroy RPC session for microTVM autotune.""" - global RPC_SESSION - - if RPC_SESSION is not None: - exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None) - RPC_SESSION = None - if (exc_type, exc_value, traceback) != (None, None, None): - exc = exc_type(exc_value) # See PEP 3109 - exc.__traceback__ = traceback - raise exc + rpc_session = Session(transport_context_manager=transport) + # RPC exit is called by shutdown function. + rpc_session.__enter__() + return rpc_session._rpc._sess diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 2dcd928b24f8..9e6664ff5984 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -404,7 +404,7 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* throw std::runtime_error(ss.str()); } std::unique_ptr channel(micro_channel); - auto ep = RPCEndpoint::Create(std::move(channel), args[0], ""); + auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]); auto sess = CreateClientSession(ep); *rv = CreateRPCSessionModule(sess); }); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index e83f062795e4..2f1fc54f39d0 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -691,11 +691,13 @@ void RPCEndpoint::Init() { * the key to modify their behavior. */ std::shared_ptr RPCEndpoint::Create(std::unique_ptr channel, - std::string name, std::string remote_key) { + std::string name, std::string remote_key, + TypedPackedFunc fshutdown) { std::shared_ptr endpt = std::make_shared(); endpt->channel_ = std::move(channel); endpt->name_ = std::move(name); endpt->remote_key_ = std::move(remote_key); + endpt->fshutdown_ = fshutdown; endpt->Init(); return endpt; } @@ -734,6 +736,7 @@ void RPCEndpoint::ServerLoop() { (*f)(); } channel_.reset(nullptr); + if (fshutdown_ != nullptr) fshutdown_(); } int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 7c11a1aeac01..f6784faba0f6 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -161,11 +161,13 @@ class RPCEndpoint { * \param channel The communication channel. * \param name The local name of the session, used for debug * \param remote_key The remote key of the session + * \param fshutdown The shutdown Packed function * if remote_key equals "%toinit", we need to re-intialize * it by event handler. */ static std::shared_ptr Create(std::unique_ptr channel, std::string name, - std::string remote_key); + std::string remote_key, + TypedPackedFunc fshutdown = nullptr); private: class EventHandler; @@ -190,6 +192,8 @@ class RPCEndpoint { std::string name_; // The remote key std::string remote_key_; + // The shutdown Packed Function + TypedPackedFunc fshutdown_; }; /*! diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 8554eeb8d18e..d2d5522b1a0a 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -436,7 +436,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): do_fork=True, build_func=tvm.micro.autotvm_build_func, ) - runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -457,6 +457,8 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): si_prefix="M", ) + assert tuner.best_flops > 0 + # Build without tuning with pass_context: lowered = tvm.relay.build(mod, target=target, params=params) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 84395e877e26..1514c51b5af0 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -254,7 +254,7 @@ def test_autotune(): inputs = {"data": input_data} target = tvm.target.target.micro("host") - template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") + template_project_dir = pathlib.Path(tvm.micro.get_standalone_crt_dir()) / "template" / "host" pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}) with pass_context: @@ -271,7 +271,7 @@ def test_autotune(): do_fork=True, build_func=tvm.micro.autotvm_build_func, ) - runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader) measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -292,6 +292,8 @@ def test_autotune(): si_prefix="M", ) + assert tuner.best_flops > 0 + # Build without tuning with pass_context: lowered = tvm.relay.build(mod, target=TARGET, params=params) diff --git a/tutorials/micro/micro_autotune.py b/tutorials/micro/micro_autotune.py index 136bcfeaec80..f89432ff01cf 100644 --- a/tutorials/micro/micro_autotune.py +++ b/tutorials/micro/micro_autotune.py @@ -125,7 +125,7 @@ do_fork=True, build_func=tvm.micro.autotvm_build_func, ) -runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) +runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -146,7 +146,7 @@ # do_fork=False, # build_func=tvm.micro.autotvm_build_func, # ) -# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) +# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) # measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -162,7 +162,7 @@ n_trial=num_trials, measure_option=measure_option, callbacks=[ - tvm.autotvm.callback.log_to_file("microtvm_autotune.log"), + tvm.autotvm.callback.log_to_file("microtvm_autotune.log.txt"), tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"), ], si_prefix="M", @@ -214,7 +214,7 @@ ########################## # Once autotuning completes, you can time execution of the entire program using the Debug Runtime: -with tvm.autotvm.apply_history_best("microtvm_autotune.log"): +with tvm.autotvm.apply_history_best("microtvm_autotune.log.txt"): with pass_context: lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params)