[AutoSchedule] Sparse dense tuning support with custom sketch rule #7313

Merged 36 commits on Mar 6, 2021. Showing changes from 28 commits.
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
namespace tvm {
namespace auto_scheduler {

const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*)
const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6"; // NOLINT(*)

/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
8 changes: 7 additions & 1 deletion include/tvm/auto_scheduler/search_task.h
@@ -26,6 +26,7 @@
#define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/target/target.h>

namespace tvm {
@@ -120,6 +121,8 @@ class SearchTaskNode : public Object {
HardwareParams hardware_params;
/*! \brief The layout rewrite option used for measuring programs. */
LayoutRewriteOption layout_rewrite_option;
/*! \brief Names of some user defined input data used in program measuring. */
Array<String> task_inputs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("compute_dag", &compute_dag);
@@ -128,6 +131,7 @@
v->Visit("target_host", &target_host);
v->Visit("hardware_params", &hardware_params);
v->Visit("layout_rewrite_option", &layout_rewrite_option);
v->Visit("task_inputs", &task_inputs);
}

static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -148,9 +152,11 @@ class SearchTask : public ObjectRef {
* \param target_host The target host device of this search task.
* \param hardware_params Hardware parameters used in this search task.
* \param layout_rewrite_option The layout rewrite option used for measuring programs.
* \param task_inputs Names of some user defined input data used in program measuring.
*/
SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option,
Array<String> task_inputs);

TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
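As a minimal usage sketch (assumed frontend usage, not code from this diff): a workload gives its special placeholder a distinctive name and the matching data is handed to SearchTask through the task_inputs form referenced in the measure.py docstring further down. The workload, shapes, and buffer name below are illustrative only.

# Minimal sketch (assumed usage, not code from this PR): give a placeholder a
# distinctive name and register matching data so measurement uses it instead of
# random values.
import numpy as np
import tvm
from tvm import te, auto_scheduler


@auto_scheduler.register_workload
def dense_with_fixed_weight(M, N, K):
    X = te.placeholder((M, K), name="X", dtype="float32")
    # A non-default placeholder name is what the measurer keys on (see
    # _prepare_input_map in measure.py below).
    W = te.placeholder((K, N), name="W_fixed", dtype="float32")
    k = te.reduce_axis((0, K), name="k")
    Y = te.compute((M, N), lambda i, j: te.sum(X[i, k] * W[k, j], axis=k), name="Y")
    return [X, W, Y]


w_np = np.random.rand(64, 64).astype("float32")

task = auto_scheduler.SearchTask(
    func=dense_with_fixed_weight,
    args=(64, 64, 64),
    target="llvm",
    # Buffer name -> concrete data; per this diff, only the names end up in the
    # C++ SearchTaskNode::task_inputs added above.
    task_inputs={"W_fixed": tvm.nd.array(w_np)},
)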
100 changes: 87 additions & 13 deletions python/tvm/auto_scheduler/measure.py
@@ -223,6 +223,7 @@ def recover_measure_input(inp, rebuild_state=False):
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
task_inputs=list(task.task_inputs),
)

if rebuild_state:
@@ -719,6 +720,45 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
return results


def _prepare_input_map(args):
"""This function deals with special task inputs.

Parameters
----------
args : List[Tensor]
Input/output Tensor of a TVM subgraph.

Returns
-------
A Dict[Tensor, str] that maps the input Tensor to a buffer name.

Note
----
The buffer name is specially designed, and these buffers should be provided in
`SearchTask(..., task_inputs={...})`.
"""
# pylint: disable=import-outside-toplevel
from tvm import topi # lazily import to avoid recursive dependency

# A dict that maps the input tensor arg to a buffer name
tensor_input_map = {}

# Case 0: Check placeholder name
for arg in args:
if isinstance(arg.op, tvm.te.PlaceholderOp):
if arg.op.name != "placeholder":
tensor_input_map[arg] = arg.op.name

# Case 1: Check sparse op
sparse_input_map = topi.nn.sparse.try_get_sparse_input(args)
Contributor:
I think I asked this before, but can we have a more general mechanism than checking only for sparse? There are other use cases that require specific input (sorting, scatter).

Contributor Author (@jcf94, Mar 4, 2021):
Yeah, I've also had some discussions in our weekly sync but didn't figure out any better solution.
There are several reasons:

  1. Different ops have different requirements for their specific inputs;
  2. When the problem is a subgraph generated by Relay integration, the placeholders are all the same: we cannot differentiate them by tag, name, or any other means, and even the order of the inputs is not guaranteed.

The current approach is to merge all specific-input checks into this function, so that at least they share the same entry point. For other ops, their own check functions have to be added below.

Contributor Author (@jcf94):
By the way, my colleague is going to add Ansor support for sparse_conv2d. We'll add an extra check to this entry first and see whether there is a better way to merge them.

Contributor:
Could we associate the lookup mechanism with @register_workload? It would at least be extensible then.

Contributor Author (@jcf94):
> Could we associate the lookup mechanism with @register_workload? It would at least be extensible then.

Thanks! This is a pretty good idea, I'll have a try.

tensor_input_map.update(sparse_input_map)

# Case 2: Check ...
# Process any other special buffers here and update them to tensor_input_map

return tensor_input_map
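A small illustration of the Case 0 rule above (example data made up here): only placeholders whose name differs from the default "placeholder" end up in the map, so ordinary Relay-generated inputs keep being random-filled.

# Illustrative only: what the Case 0 check above would produce.
import tvm
from tvm import te

A = te.placeholder((16,), name="placeholder")  # default name -> not mapped, random-filled
B = te.placeholder((16,), name="B_special")    # custom name  -> mapped to its buffer name
C = te.compute((16,), lambda i: A[i] + B[i], name="C")

# Under this scheme, _prepare_input_map([A, B, C]) returns {B: "B_special"},
# and the measurer then expects a task input registered under "B_special".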


def _timed_eval_func(
inp_serialized,
build_res,
@@ -729,7 +769,11 @@
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_inputs = inp.task.task_inputs
tic = time.time()
error_no = 0
error_msg = None
@@ -758,11 +802,25 @@

if error_no == 0:
try:
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
for arg in args:
random_fill(arg)

tensor_input_map = _prepare_input_map(build_res.args) if task_inputs else {}
args = []
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_inputs:
args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with SearchTask.AddTaskInput()"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
random_fill(empty_array)
args.append(empty_array)
ctx.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
@@ -911,7 +969,11 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
# pylint: disable=import-outside-toplevel
from .search_task import get_task_input_buffer # lazily import to avoid recursive dependency

inp = MeasureInput.deserialize(inp_serialized)
task_inputs = inp.task.task_inputs
tic = time.time()
error_no = 0
error_msg = None
@@ -943,18 +1005,30 @@

if error_no == 0:
try:
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
try:
random_fill = remote.get_function("tvm.contrib.random.random_fill")
except AttributeError:
raise AttributeError(
"Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
)
for arg in args:
random_fill(arg)
random_fill = remote.get_function("tvm.contrib.random.random_fill")
assert (
random_fill
), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"

tensor_input_map = _prepare_input_map(build_res.args) if task_inputs else {}
args = []
for arg in build_res.args:
if arg in tensor_input_map:
tensor_name = tensor_input_map[arg]
if tensor_name in task_inputs:
args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
else:
raise ValueError(
"%s not found in task_inputs, " % (tensor_name)
+ "should provide with SearchTask.AddTaskInput()"
)
else:
empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
random_fill(empty_array)
args.append(empty_array)
ctx.sync()

costs = time_f(*args).results

# clean up remote files
remote.remove(build_res.filename)
remote.remove(os.path.splitext(build_res.filename)[0] + ".so")