diff --git a/python/tvm/autotvm/measure/__init__.py b/python/tvm/autotvm/measure/__init__.py index 0c32ae0ca9bf..c4c0dc92b116 100644 --- a/python/tvm/autotvm/measure/__init__.py +++ b/python/tvm/autotvm/measure/__init__.py @@ -23,6 +23,12 @@ measure_option, create_measure_batch, ) -from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote +from .measure_methods import ( + LocalBuilder, + LocalRunner, + RPCRunner, + default_module_loader, + request_remote, +) from .executor import Executor from .local_executor import LocalExecutor diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index ffe4b97e33db..62fd811dc1ec 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -22,11 +22,13 @@ remote devices, recording the running time costs, and checking the correctness of the output. """ +import contextlib import logging import shutil import os import threading import time +import typing from random import getrandbits from collections import namedtuple import tempfile @@ -199,6 +201,9 @@ class RPCRunner(Runner): its actual latency during end-to-end inference. To make this option effective, the argument `number` should also be set to 1. This is only has effect on CPU task. + module_loader : ModuleLoader + If given, a context manager that loads the module to be timed into the remote runtime. + If not given, default_module_loader is used. """ def __init__( @@ -214,6 +219,7 @@ def __init__( min_repeat_ms=0, cooldown_interval=0.1, enable_cpu_cache_flush=False, + module_loader=None, ): super(RPCRunner, self).__init__(timeout, n_parallel) @@ -229,6 +235,7 @@ def __init__( self.enable_cpu_cache_flush = enable_cpu_cache_flush self.cooldown_interval = cooldown_interval + self.module_loader = module_loader self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1)) @@ -280,6 +287,11 @@ def run(self, measure_inputs, build_results): for measure_inp, build_res in zip( measure_inputs[i : i + self.n_parallel], build_results[i : i + self.n_parallel] ): + module_loader = ( + self.module_loader + if self.module_loader is not None + else default_module_loader() + ) ret = self.executor.submit( run_through_rpc, measure_inp, @@ -290,6 +302,7 @@ def run(self, measure_inputs, build_results): self.cooldown_interval, remote_args, self.enable_cpu_cache_flush, + module_loader, ) futures.append(ret) @@ -352,6 +365,7 @@ def __init__( min_repeat_ms=0, cooldown_interval=0.1, enable_cpu_cache_flush=False, + module_loader=None, ): super(LocalRunner, self).__init__( "", @@ -365,6 +379,7 @@ def __init__( min_repeat_ms=min_repeat_ms, cooldown_interval=cooldown_interval, enable_cpu_cache_flush=enable_cpu_cache_flush, + module_loader=module_loader, ) self.tracker = None self.server = None @@ -473,6 +488,11 @@ def __call__(self, measure_input, tmp_dir, **kwargs): return BuildResult(filename, arg_info, None, time.time() - tic) +ModuleLoader = typing.Callable[ + [dict, dict], typing.ContextManager[typing.Tuple[tvm.rpc.RPCSession, tvm.runtime.Module]] +] + + def run_through_rpc( measure_input, build_result, @@ -480,8 +500,9 @@ def run_through_rpc( repeat, min_repeat_ms, cooldown_interval, - remote_args, + remote_kwargs, enable_cpu_cache_flush=False, + module_loader=None, ): """Run a generated library through rpc @@ -509,14 +530,16 @@ def run_through_rpc( will be automatically increased. cooldown_interval: float The cool down interval between two measurements - remote_args: Tuple - The argument for request_remote + remote_kwargs: dict + Passed to module_loader(). Ultimately, keyword args to request_remote(). enable_cpu_cache_flush: bool Whether to flush cache on CPU between repeated measurements. Flushing cache can make the measured latency of one operator closer to its actual latency during end-to-end inference. To make this option effective, the argument `number` should also be set to 1. This is only has effect on CPU task. + module_loader: ModuleLoader + A function that returns a ContextManager used to establish and teardown the remote session. """ if isinstance(build_result, MeasureResult): return build_result @@ -525,55 +548,38 @@ def run_through_rpc( errno = MeasureErrorNo.NO_ERROR try: # upload built module - remote = request_remote(*remote_args) - # Program the FPGA every single time when targeting VTA - if ( - hasattr(measure_input.target, "device_name") - and measure_input.target.device_name == "vta" - ): - # pylint: disable=import-outside-toplevel - from vta import program_fpga, reconfig_runtime - - program_fpga(remote, None) - reconfig_runtime(remote) - remote.upload(build_result.filename) - func = remote.load_module(os.path.split(build_result.filename)[1]) - ctx = remote.context(str(measure_input.target), 0) - - # Limitation: - # We can not get PackFunction directly in the remote mode as it is wrapped - # under the std::function. We could lift the restriction later once we fold - # the PackedFunc as an object. Currently, we pass function name to work - # around it. - f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else "" - time_f = func.time_evaluator( - func.entry_name, - ctx, - number=number, - repeat=repeat, - min_repeat_ms=min_repeat_ms, - f_preproc=f_prepare, - ) - - try: - random_fill = remote.get_function("tvm.contrib.random.random_fill") - except AttributeError: - raise AttributeError( - "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" + with module_loader(remote_kwargs, build_result) as (remote, mod): + ctx = remote.context(str(measure_input.target), 0) + + # Limitation: + # We can not get PackFunction directly in the remote mode as it is wrapped + # under the std::function. We could lift the restriction later once we fold + # the PackedFunc as an object. Currently, we pass function name to work + # around it. + f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else "" + time_f = mod.time_evaluator( + mod.entry_name, + ctx, + number=number, + repeat=repeat, + min_repeat_ms=min_repeat_ms, + f_preproc=f_prepare, ) - args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info] - if "scatter" not in measure_input.task.name: - # the index tensor of scatter op cannot be randomly initialized - for arg in args: - random_fill(arg) - ctx.sync() - costs = time_f(*args).results + try: + random_fill = remote.get_function("tvm.contrib.random.random_fill") + except AttributeError: + raise AttributeError( + "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices" + ) + args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info] + if "scatter" not in measure_input.task.name: + # the index tensor of scatter op cannot be randomly initialized + for arg in args: + random_fill(arg) + ctx.sync() - # clean up remote files - remote.remove(build_result.filename) - remote.remove(os.path.splitext(build_result.filename)[0] + ".so") - remote.remove("") + costs = time_f(*args).results if len(costs) > 2: # remove largest and smallest value to reduce variance costs = list(costs) @@ -592,6 +598,40 @@ def run_through_rpc( return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp) +def default_module_loader(pre_load_function=None): + """Returns a default function that can be passed as module_loader to run_through_rpc. + + Parameters + ---------- + pre_load_function : Optional[Function[tvm.rpc.Session, tvm.runtime.Module]] + Invoked after a session is established and before the default code-loading RPC calls are + issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment. + + Returns + ------- + ModuleLoader : + A function that can be passed as module_loader to run_through_rpc. + """ + + @contextlib.contextmanager + def default_module_loader_mgr(remote_kwargs, build_result): + remote = request_remote(**remote_kwargs) + if pre_load_function is not None: + pre_load_function(remote, build_result) + + remote.upload(build_result.filename) + try: + yield remote, remote.load_module(os.path.split(build_result.filename)[1]) + + finally: + # clean up remote files + remote.remove(build_result.filename) + remote.remove(os.path.splitext(build_result.filename)[0] + ".so") + remote.remove("") + + return default_module_loader_mgr + + def request_remote(device_key, host=None, port=None, priority=1, timeout=60): """Request a remote session diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index d143c4db6884..5fce76808c45 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -22,6 +22,7 @@ """ import sys +from .autotvm import module_loader from .bitstream import get_bitstream_path, download_bitstream from .environment import get_env, Environment from .rpc_client import reconfig_runtime, program_fpga diff --git a/vta/python/vta/autotvm.py b/vta/python/vta/autotvm.py new file mode 100644 index 000000000000..9aa7390f238f --- /dev/null +++ b/vta/python/vta/autotvm.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines AutoTVM components used with VTA.""" + +from tvm.autotvm.measure import default_module_loader +from . import rpc_client + + +def module_loader(bitstream=None): + """Construct a ModuleLoader implementation specialized for VTA. + + Parameters + ---------- + bitsream : Optional[str] + Path to the bitstream to write prior to uploading code. + + Returns + ------- + ModuleLoader : + The ModuleLoader instance. + """ + + def reprogram_fpga(remote, _build_result): + """default_module_loader callback which reprograms the FPGA. + + Parameters + ---------- + remote : tvm.rpc.RPCSession + RPC session established to the remote device. + + _build_result : tvm.autotvm.measure.measure_methods.BuildResult + Artifact from the build phase, unused here. + """ + rpc_client.program_bitstream(remote, bitstream) + rpc_client.reconfig_runtime(remote) + + return default_module_loader(reprogram_fpga) diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index c5885b65c0f3..ed2671c75ae8 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -215,6 +215,7 @@ def compile_network(env, target, model, start_pack, stop_pack): port=tracker_port, number=5, timeout=60, + module_loader=vta.module_loader(), # check_correctness=True, # TODO: re-enable when check_correctness works again. ), ),