Skip to content

Commit

Permalink
cleanup and fix tutorial, zephyr and crt test
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Sep 13, 2021
1 parent 3628fc1 commit f95c459
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 48 deletions.
3 changes: 0 additions & 3 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,6 @@ def run(self, measure_inputs, build_results):
res = future.result()
results.append(res)
except Exception as ex: # pylint: disable=broad-except
# import pdb; pdb.set_trace()
logging.debug("Mehrdad")
logging.debug(f"exception: {str(ex)}")
results.append(
MeasureResult(
(str(ex),), MeasureErrorNo.RUN_TIMEOUT, self.timeout, time.time()
Expand Down
4 changes: 1 addition & 3 deletions python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,8 @@ def _start(self):
cmd += [str(worker_read_handle), str(worker_write_handle)]
self._proc = subprocess.Popen(cmd, close_fds=False)
else:
import logging
logging.debug("mehrdad popen")
cmd += [str(worker_read), str(worker_write)]
self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write), stdout=subprocess.STDOUT)
self._proc = subprocess.Popen(cmd, pass_fds=(worker_read, worker_write))

# close worker side of the pipe
os.close(worker_read)
Expand Down
11 changes: 0 additions & 11 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,6 @@ def __call__(self, remote_kw, build_result):
build_result_bin = build_file.read()

tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"])
import sys
with open("mehrdad_log.log", "w") as f:
f.write(f"mehrdad: {str(type(self._template_project_dir))}\n")
f.write(f"mehrdad: {self._template_project_dir}\n")
# import pdb; pdb.set_trace()
# sys.stdout.write("merhdad type:")
# sys.stderr.write(f"mehrdad: {str(type(self._template_project_dir))}")
remote = tracker.request(
remote_kw["device_key"],
priority=remote_kw["priority"],
Expand All @@ -109,10 +102,6 @@ def __call__(self, remote_kw, build_result):
)
system_lib = remote.get_function("runtime.SystemLib")()
yield remote, system_lib
try:
remote.get_function("tvm.micro.destroy_micro_session")()
except tvm.error.TVMError as exception:
_LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1)


def autotvm_build_func():
Expand Down
31 changes: 9 additions & 22 deletions python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __enter__(self):
int(timeouts.session_start_retry_timeout_sec * 1e6),
int(timeouts.session_start_timeout_sec * 1e6),
int(timeouts.session_established_timeout_sec * 1e6),
self._shutdown,
)
)
self.device = self._rpc.cpu(0)
Expand All @@ -143,6 +144,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
"""Tear down this session and associated RPC session resources."""
self.transport.__exit__(exc_type, exc_value, exc_traceback)

def _shutdown(self):
self.__exit__(None, None, None)


def lookup_remote_linked_param(mod, storage_id, template_tensor, device):
"""Lookup a parameter that has been pre-linked into a remote (i.e. over RPC) Module.
Expand Down Expand Up @@ -239,9 +243,6 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
)


RPC_SESSION = None


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
Expand All @@ -264,7 +265,6 @@ def compile_and_create_micro_session(
project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""
global RPC_SESSION

temp_dir = utils.tempdir()
# Keep temp directory for generate project
Expand All @@ -277,7 +277,7 @@ def compile_and_create_micro_session(
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
temp_dir / "generated-project",
str(temp_dir / "generated-project"),
options=json.loads(project_options),
)
except Exception as exception:
Expand All @@ -288,20 +288,7 @@ def compile_and_create_micro_session(
generated_project.flash()
transport = generated_project.transport()

RPC_SESSION = Session(transport_context_manager=transport)
RPC_SESSION.__enter__()
return RPC_SESSION._rpc._sess


@register_func
def destroy_micro_session():
"""Destroy RPC session for microTVM autotune."""
global RPC_SESSION

if RPC_SESSION is not None:
exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None)
RPC_SESSION = None
if (exc_type, exc_value, traceback) != (None, None, None):
exc = exc_type(exc_value) # See PEP 3109
exc.__traceback__ = traceback
raise exc
rpc_session = Session(transport_context_manager=transport)
# RPC exit is called by shutdown function.
rpc_session.__enter__()
return rpc_session._rpc._sess
2 changes: 1 addition & 1 deletion src/runtime/micro/micro_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue*
throw std::runtime_error(ss.str());
}
std::unique_ptr<RPCChannel> channel(micro_channel);
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "");
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]);
auto sess = CreateClientSession(ep);
*rv = CreateRPCSessionModule(sess);
});
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,13 @@ void RPCEndpoint::Init() {
* the key to modify their behavior.
*/
std::shared_ptr<RPCEndpoint> RPCEndpoint::Create(std::unique_ptr<RPCChannel> channel,
std::string name, std::string remote_key) {
std::string name, std::string remote_key,
TypedPackedFunc<void()> fshutdown) {
std::shared_ptr<RPCEndpoint> endpt = std::make_shared<RPCEndpoint>();
endpt->channel_ = std::move(channel);
endpt->name_ = std::move(name);
endpt->remote_key_ = std::move(remote_key);
endpt->fshutdown_ = fshutdown;
endpt->Init();
return endpt;
}
Expand Down Expand Up @@ -734,6 +736,7 @@ void RPCEndpoint::ServerLoop() {
(*f)();
}
channel_.reset(nullptr);
if (fshutdown_ != nullptr) fshutdown_();
}

int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) {
Expand Down
6 changes: 5 additions & 1 deletion src/runtime/rpc/rpc_endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ class RPCEndpoint {
* \param channel The communication channel.
* \param name The local name of the session, used for debug
* \param remote_key The remote key of the session
* \param fshutdown The shutdown Packed function
* if remote_key equals "%toinit", we need to re-intialize
* it by event handler.
*/
static std::shared_ptr<RPCEndpoint> Create(std::unique_ptr<RPCChannel> channel, std::string name,
std::string remote_key);
std::string remote_key,
TypedPackedFunc<void()> fshutdown = nullptr);

private:
class EventHandler;
Expand All @@ -190,6 +192,8 @@ class RPCEndpoint {
std::string name_;
// The remote key
std::string remote_key_;
// The shutdown Packed Function
TypedPackedFunc<void()> fshutdown_;
};

/*!
Expand Down
4 changes: 3 additions & 1 deletion tests/micro/zephyr/test_zephyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -457,6 +457,8 @@ def test_autotune_conv2d(temp_dir, board, west_cmd, tvm_debug):
si_prefix="M",
)

assert tuner.best_flops > 0

# Build without tuning
with pass_context:
lowered = tvm.relay.build(mod, target=target, params=params)
Expand Down
4 changes: 3 additions & 1 deletion tests/python/unittest/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_autotune():
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -286,6 +286,8 @@ def test_autotune():
si_prefix="M",
)

assert tuner.best_flops > 0

# Build without tuning
with pass_context:
lowered = tvm.relay.build(mod, target=TARGET, params=params)
Expand Down
8 changes: 4 additions & 4 deletions tutorials/micro/micro_autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
do_fork=True,
build_func=tvm.micro.autotvm_build_func,
)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
runner = tvm.autotvm.LocalRunner(number=1, repeat=1, module_loader=module_loader)

measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -146,7 +146,7 @@
# do_fork=False,
# build_func=tvm.micro.autotvm_build_func,
# )
# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=100, module_loader=module_loader)

# measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)

Expand All @@ -162,7 +162,7 @@
n_trial=num_trials,
measure_option=measure_option,
callbacks=[
tvm.autotvm.callback.log_to_file("microtvm_autotune.log"),
tvm.autotvm.callback.log_to_file("microtvm_autotune.log.txt"),
tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"),
],
si_prefix="M",
Expand Down Expand Up @@ -214,7 +214,7 @@
##########################
# Once autotuning completes, you can time execution of the entire program using the Debug Runtime:

with tvm.autotvm.apply_history_best("microtvm_autotune.log"):
with tvm.autotvm.apply_history_best("microtvm_autotune.log.txt"):
with pass_context:
lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params)

Expand Down

0 comments on commit f95c459

Please sign in to comment.