From f95c459a9fae8f757f9ce7f3797280770b39b348 Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Mon, 13 Sep 2021 16:24:45 -0700 Subject: [PATCH] cleanup and fix tutorial, zephyr and crt test --- python/tvm/autotvm/measure/measure_methods.py | 3 -- python/tvm/contrib/popen_pool.py | 4 +-- python/tvm/micro/build.py | 11 ------- python/tvm/micro/session.py | 31 ++++++------------- src/runtime/micro/micro_session.cc | 2 +- src/runtime/rpc/rpc_endpoint.cc | 5 ++- src/runtime/rpc/rpc_endpoint.h | 6 +++- tests/micro/zephyr/test_zephyr.py | 4 ++- tests/python/unittest/test_crt.py | 4 ++- tutorials/micro/micro_autotune.py | 8 ++--- 10 files changed, 30 insertions(+), 48 deletions(-) diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index b4f1b44b6f46..efe45daa1464 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -359,9 +359,6 @@ def run(self, measure_inputs, build_results): res = future.result() results.append(res) except Exception as ex: # pylint: disable=broad-except - # import pdb; pdb.set_trace() - logging.debug("Mehrdad") - logging.debug(f"exception: {str(ex)}") results.append( MeasureResult( (str(ex),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time() diff --git a/python/tvm/contrib/popen_pool.py b/python/tvm/contrib/popen_pool.py index a1a810700c4a..907231c1a9fa 100644 --- a/python/tvm/contrib/popen_pool.py +++ b/python/tvm/contrib/popen_pool.py @@ -156,10 +156,8 @@ def _start(self): cmd += [str(worker_read_handle), str(worker_write_handle)] self._proc = subprocess.Popen(cmd, close_fds=False) else: - import logging - logging.debug("mehrdad popen") cmd += [str(worker_read), str(worker_write)] - self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write), stdout=subprocess.STDOUT) + self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write)) # close worker side of the pipe os.close(worker_read) diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 8dc2b9180226..bef1e32fe07e 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -89,13 +89,6 @@ def __call__(self, remote_kw, build_result): build_result_bin = build_file.read() tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"]) - import sys - with open("mehrdad_log.log", "w") as f: - f.write(f"mehrdad: {str(type(self._template_project_dir))}\n") - f.write(f"mehrdad: {self._template_project_dir}\n") - # import pdb; pdb.set_trace() - # sys.stdout.write("merhdad type:") - # sys.stderr.write(f"mehrdad: {str(type(self._template_project_dir))}") remote = tracker.request( remote_kw["device_key"], priority=remote_kw["priority"], @@ -109,10 +102,6 @@ def __call__(self, remote_kw, build_result): ) system_lib = remote.get_function("runtime.SystemLib")() yield remote, system_lib - try: - remote.get_function("tvm.micro.destroy_micro_session")() - except tvm.error.TVMError as exception: - _LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1) def autotvm_build_func(): diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index abe7aff766e2..ced20b7ebfbf 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -130,6 +130,7 @@ def __enter__(self): int(timeouts.session_start_retry_timeout_sec * 1e6), int(timeouts.session_start_timeout_sec * 1e6), int(timeouts.session_established_timeout_sec * 1e6), + self._shutdown, ) ) self.device = self._rpc.cpu(0) @@ -143,6 +144,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback): """Tear down this session and associated RPC session resources.""" self.transport.__exit__(exc_type, exc_value, exc_traceback) + def _shutdown(self): + self.__exit__(None, None, None) + def lookup_remote_linked_param(mod, storage_id, template_tensor, device): """Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module. @@ -239,9 +243,6 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None): ) -RPC_SESSION = None - - @register_func("tvm.micro.compile_and_create_micro_session") def compile_and_create_micro_session( mod_src_bytes: bytes, @@ -264,7 +265,6 @@ def compile_and_create_micro_session( project_options: dict Options for the microTVM API Server contained in template_project_dir. """ - global RPC_SESSION temp_dir = utils.tempdir() # Keep temp directory for generate project @@ -277,7 +277,7 @@ def compile_and_create_micro_session( template_project = project.TemplateProject.from_directory(template_project_dir) generated_project = template_project.generate_project_from_mlf( model_library_format_path, - temp_dir / "generated-project", + str(temp_dir / "generated-project"), options=json.loads(project_options), ) except Exception as exception: @@ -288,20 +288,7 @@ def compile_and_create_micro_session( generated_project.flash() transport = generated_project.transport() - RPC_SESSION = Session(transport_context_manager=transport) - RPC_SESSION.__enter__() - return RPC_SESSION._rpc._sess - - -@register_func -def destroy_micro_session(): - """Destroy RPC session for microTVM autotune.""" - global RPC_SESSION - - if RPC_SESSION is not None: - exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None) - RPC_SESSION = None - if (exc_type, exc_value, traceback) != (None, None, None): - exc = exc_type(exc_value) # See PEP 3109 - exc.__traceback__ = traceback - raise exc + rpc_session = Session(transport_context_manager=transport) + # RPC exit is called by shutdown function. + rpc_session.__enter__() + return rpc_session._rpc._sess diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 2dcd928b24f8..9e6664ff5984 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -404,7 +404,7 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* throw std::runtime_error(ss.str()); } std::unique_ptr channel(micro_channel); - auto ep = RPCEndpoint::Create(std::move(channel), args[0], ""); + auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]); auto sess = CreateClientSession(ep); *rv = CreateRPCSessionModule(sess); }); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index e83f062795e4..2f1fc54f39d0 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -691,11 +691,13 @@ void RPCEndpoint::Init() { * the key to modify their behavior. */ std::shared_ptr RPCEndpoint::Create(std::unique_ptr channel, - std::string name, std::string remote_key) { + std::string name, std::string remote_key, + TypedPackedFunc fshutdown) { std::shared_ptr endpt = std::make_shared(); endpt->channel_ = std::move(channel); endpt->name_ = std::move(name); endpt->remote_key_ = std::move(remote_key); + endpt->fshutdown_ = fshutdown; endpt->Init(); return endpt; } @@ -734,6 +736,7 @@ void RPCEndpoint::ServerLoop() { (*f)(); } channel_.reset(nullptr); + if (fshutdown_ != nullptr) fshutdown_(); } int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 7c11a1aeac01..f6784faba0f6 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -161,11 +161,13 @@ class RPCEndpoint { * \param channel The communication channel. * \param name The local name of the session, used for debug * \param remote_key The remote key of the session + * \param fshutdown The shutdown Packed function * if remote_key equals "%toinit", we need to re-intialize * it by event handler. */ static std::shared_ptr Create(std::unique_ptr channel, std::string name, - std::string remote_key); + std::string remote_key, + TypedPackedFunc fshutdown = nullptr); private: class EventHandler; @@ -190,6 +192,8 @@ class RPCEndpoint { std::string name_; // The remote key std::string remote_key_; + // The shutdown Packed Function + TypedPackedFunc fshutdown_; }; /*! diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 8554eeb8d18e..d2d5522b1a0a 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -436,7 +436,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): do_fork=True, build_func=tvm.micro.autotvm_build_func, ) - runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -457,6 +457,8 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug): si_prefix="M", ) + assert tuner.best_flops > 0 + # Build without tuning with pass_context: lowered = tvm.relay.build(mod, target=target, params=params) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 5c6eb922fa17..989357a5589a 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -265,7 +265,7 @@ def test_autotune(): do_fork=True, build_func=tvm.micro.autotvm_build_func, ) - runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader) measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -286,6 +286,8 @@ def test_autotune(): si_prefix="M", ) + assert tuner.best_flops > 0 + # Build without tuning with pass_context: lowered = tvm.relay.build(mod, target=TARGET, params=params) diff --git a/tutorials/micro/micro_autotune.py b/tutorials/micro/micro_autotune.py index 136bcfeaec80..a10fb23158ab 100644 --- a/tutorials/micro/micro_autotune.py +++ b/tutorials/micro/micro_autotune.py @@ -125,7 +125,7 @@ do_fork=True, build_func=tvm.micro.autotvm_build_func, ) -runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) +runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader) measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -146,7 +146,7 @@ # do_fork=False, # build_func=tvm.micro.autotvm_build_func, # ) -# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) +# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader) # measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) @@ -162,7 +162,7 @@ n_trial=num_trials, measure_option=measure_option, callbacks=[ - tvm.autotvm.callback.log_to_file("microtvm_autotune.log"), + tvm.autotvm.callback.log_to_file("microtvm_autotune.log.txt"), tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"), ], si_prefix="M", @@ -214,7 +214,7 @@ ########################## # Once autotuning completes, you can time execution of the entire program using the Debug Runtime: -with tvm.autotvm.apply_history_best("microtvm_autotune.log"): +with tvm.autotvm.apply_history_best("microtvm_autotune.log.txt"): with pass_context: lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params)