Skip to content

Commit

Permalink
Introduce module_loader to AutoTVM. (apache#7337)
Browse files Browse the repository at this point in the history
* Introduce code_loader to AutoTVM.

 * Prepares for autotuning with microTVM, and provides extension hook
   for VTA.

* add vta hook

* git-black

* pylint

* Add missing import

* Fix import problem

* add missing import

* rename code_loader to module_loader

* rename remote_kw to remote_kwargs

* black format
  • Loading branch information
areusch authored and Lokiiiiii committed Mar 1, 2021
1 parent ac63b6c commit ad6a84b
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 50 deletions.
8 changes: 7 additions & 1 deletion python/tvm/autotvm/measure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
measure_option,
create_measure_batch,
)
from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
from .measure_methods import (
LocalBuilder,
LocalRunner,
RPCRunner,
default_module_loader,
request_remote,
)
from .executor import Executor
from .local_executor import LocalExecutor
138 changes: 89 additions & 49 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
remote devices, recording the running time costs, and checking the correctness of the output.
"""

import contextlib
import logging
import shutil
import os
import threading
import time
import typing
from random import getrandbits
from collections import namedtuple
import tempfile
Expand Down Expand Up @@ -199,6 +201,9 @@ class RPCRunner(Runner):
its actual latency during end-to-end inference.
To make this option effective, the argument `number` should also be set to 1.
This is only has effect on CPU task.
module_loader : ModuleLoader
If given, a context manager that loads the module to be timed into the remote runtime.
If not given, default_module_loader is used.
"""

def __init__(
Expand All @@ -214,6 +219,7 @@ def __init__(
min_repeat_ms=0,
cooldown_interval=0.1,
enable_cpu_cache_flush=False,
module_loader=None,
):
super(RPCRunner, self).__init__(timeout, n_parallel)

Expand All @@ -229,6 +235,7 @@ def __init__(

self.enable_cpu_cache_flush = enable_cpu_cache_flush
self.cooldown_interval = cooldown_interval
self.module_loader = module_loader

self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1))

Expand Down Expand Up @@ -280,6 +287,11 @@ def run(self, measure_inputs, build_results):
for measure_inp, build_res in zip(
measure_inputs[i : i + self.n_parallel], build_results[i : i + self.n_parallel]
):
module_loader = (
self.module_loader
if self.module_loader is not None
else default_module_loader()
)
ret = self.executor.submit(
run_through_rpc,
measure_inp,
Expand All @@ -290,6 +302,7 @@ def run(self, measure_inputs, build_results):
self.cooldown_interval,
remote_args,
self.enable_cpu_cache_flush,
module_loader,
)
futures.append(ret)

Expand Down Expand Up @@ -352,6 +365,7 @@ def __init__(
min_repeat_ms=0,
cooldown_interval=0.1,
enable_cpu_cache_flush=False,
module_loader=None,
):
super(LocalRunner, self).__init__(
"",
Expand All @@ -365,6 +379,7 @@ def __init__(
min_repeat_ms=min_repeat_ms,
cooldown_interval=cooldown_interval,
enable_cpu_cache_flush=enable_cpu_cache_flush,
module_loader=module_loader,
)
self.tracker = None
self.server = None
Expand Down Expand Up @@ -473,15 +488,21 @@ def __call__(self, measure_input, tmp_dir, **kwargs):
return BuildResult(filename, arg_info, None, time.time() - tic)


ModuleLoader = typing.Callable[
[dict, dict], typing.ContextManager[typing.Tuple[tvm.rpc.RPCSession, tvm.runtime.Module]]
]


def run_through_rpc(
measure_input,
build_result,
number,
repeat,
min_repeat_ms,
cooldown_interval,
remote_args,
remote_kwargs,
enable_cpu_cache_flush=False,
module_loader=None,
):
"""Run a generated library through rpc
Expand Down Expand Up @@ -509,14 +530,16 @@ def run_through_rpc(
will be automatically increased.
cooldown_interval: float
The cool down interval between two measurements
remote_args: Tuple
The argument for request_remote
remote_kwargs: dict
Passed to module_loader(). Ultimately, keyword args to request_remote().
enable_cpu_cache_flush: bool
Whether to flush cache on CPU between repeated measurements.
Flushing cache can make the measured latency of one operator closer to
its actual latency during end-to-end inference.
To make this option effective, the argument `number` should also be set to 1.
This is only has effect on CPU task.
module_loader: ModuleLoader
A function that returns a ContextManager used to establish and teardown the remote session.
"""
if isinstance(build_result, MeasureResult):
return build_result
Expand All @@ -525,55 +548,38 @@ def run_through_rpc(
errno = MeasureErrorNo.NO_ERROR
try:
# upload built module
remote = request_remote(*remote_args)
# Program the FPGA every single time when targeting VTA
if (
hasattr(measure_input.target, "device_name")
and measure_input.target.device_name == "vta"
):
# pylint: disable=import-outside-toplevel
from vta import program_fpga, reconfig_runtime

program_fpga(remote, None)
reconfig_runtime(remote)
remote.upload(build_result.filename)
func = remote.load_module(os.path.split(build_result.filename)[1])
ctx = remote.context(str(measure_input.target), 0)

# Limitation:
# We can not get PackFunction directly in the remote mode as it is wrapped
# under the std::function. We could lift the restriction later once we fold
# the PackedFunc as an object. Currently, we pass function name to work
# around it.
f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
time_f = func.time_evaluator(
func.entry_name,
ctx,
number=number,
repeat=repeat,
min_repeat_ms=min_repeat_ms,
f_preproc=f_prepare,
)

try:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
except AttributeError:
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
with module_loader(remote_kwargs, build_result) as (remote, mod):
ctx = remote.context(str(measure_input.target), 0)

# Limitation:
# We can not get PackFunction directly in the remote mode as it is wrapped
# under the std::function. We could lift the restriction later once we fold
# the PackedFunc as an object. Currently, we pass function name to work
# around it.
f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
time_f = mod.time_evaluator(
mod.entry_name,
ctx,
number=number,
repeat=repeat,
min_repeat_ms=min_repeat_ms,
f_preproc=f_prepare,
)
args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
ctx.sync()

costs = time_f(*args).results
try:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
except AttributeError:
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
if "scatter" not in measure_input.task.name:
# the index tensor of scatter op cannot be randomly initialized
for arg in args:
random_fill(arg)
ctx.sync()

# clean up remote files
remote.remove(build_result.filename)
remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
remote.remove("")
costs = time_f(*args).results

if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
Expand All @@ -592,6 +598,40 @@ def run_through_rpc(
return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)


def default_module_loader(pre_load_function=None):
"""Returns a default function that can be passed as module_loader to run_through_rpc.
Parameters
----------
pre_load_function : Optional[Function[tvm.rpc.Session, tvm.runtime.Module]]
Invoked after a session is established and before the default code-loading RPC calls are
issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment.
Returns
-------
ModuleLoader :
A function that can be passed as module_loader to run_through_rpc.
"""

@contextlib.contextmanager
def default_module_loader_mgr(remote_kwargs, build_result):
remote = request_remote(**remote_kwargs)
if pre_load_function is not None:
pre_load_function(remote, build_result)

remote.upload(build_result.filename)
try:
yield remote, remote.load_module(os.path.split(build_result.filename)[1])

finally:
# clean up remote files
remote.remove(build_result.filename)
remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
remote.remove("")

return default_module_loader_mgr


def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
"""Request a remote session
Expand Down
1 change: 1 addition & 0 deletions vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""
import sys

from .autotvm import module_loader
from .bitstream import get_bitstream_path, download_bitstream
from .environment import get_env, Environment
from .rpc_client import reconfig_runtime, program_fpga
Expand Down
52 changes: 52 additions & 0 deletions vta/python/vta/autotvm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines AutoTVM components used with VTA."""

from tvm.autotvm.measure import default_module_loader
from . import rpc_client


def module_loader(bitstream=None):
"""Construct a ModuleLoader implementation specialized for VTA.
Parameters
----------
bitsream : Optional[str]
Path to the bitstream to write prior to uploading code.
Returns
-------
ModuleLoader :
The ModuleLoader instance.
"""

def reprogram_fpga(remote, _build_result):
"""default_module_loader callback which reprograms the FPGA.
Parameters
----------
remote : tvm.rpc.RPCSession
RPC session established to the remote device.
_build_result : tvm.autotvm.measure.measure_methods.BuildResult
Artifact from the build phase, unused here.
"""
rpc_client.program_bitstream(remote, bitstream)
rpc_client.reconfig_runtime(remote)

return default_module_loader(reprogram_fpga)
1 change: 1 addition & 0 deletions vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def compile_network(env, target, model, start_pack, stop_pack):
port=tracker_port,
number=5,
timeout=60,
module_loader=vta.module_loader(),
# check_correctness=True, # TODO: re-enable when check_correctness works again.
),
),
Expand Down

0 comments on commit ad6a84b

Please sign in to comment.