[AutoSchedule] Sparse dense tuning support with custom sketch rule #7313

Merged: 36 commits, Mar 6, 2021
Changes from 4 commits
177 changes: 170 additions & 7 deletions python/tvm/auto_scheduler/measure.py
@@ -37,12 +37,15 @@
import tempfile
import multiprocessing

import numpy as np

import tvm._ffi
from tvm.runtime import Object, module, ndarray
from tvm.driver import build_module
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.te import PlaceholderOp, ComputeOp

from . import _ffi_api
from .loop_state import StateObject
@@ -719,6 +722,87 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
    return results


def _process_sparse_input(args):
    sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None

    def _process_inputs(input_tensors, M, N, prefix_init):
        nonlocal sparse_prefix
        nonlocal sparse_data
        nonlocal sparse_indices
        nonlocal sparse_indptr

        assert len(input_tensors) == 4
        unsure_tensors = list(input_tensors)
        # Get the Dense data
        dense_data = None
        for tensor in unsure_tensors:
            if len(tensor.shape) == 2:
                assert dense_data is None
                dense_data = tensor
                assert M == dense_data.shape[0]
                K = dense_data.shape[1]
        unsure_tensors.remove(dense_data)

        # Get the Sparse data
        sparse_data = None
        for tensor in unsure_tensors:
            if len(tensor.shape) == 3:
                assert sparse_data is None
                sparse_data = tensor
                block_size, BS_R, BS_C = sparse_data.shape
        unsure_tensors.remove(sparse_data)

        # Get the Sparse indptr & indices
        sparse_indices = None
        for tensor in unsure_tensors:
            assert len(tensor.shape) == 1
            if tensor.shape[0] == block_size:
                assert sparse_indices is None
                sparse_indices = tensor
        unsure_tensors.remove(sparse_indices)
        assert len(unsure_tensors) == 1
        sparse_indptr = unsure_tensors[0]

        # Generate the sparse_prefix. The density is the fraction of non-zero
        # elements in the N x K weight: num_blocks * BS_R * BS_C / (K * N).
        density = 1.0
        for i in sparse_data.shape:
            density *= i
        density /= (K * N)
        density = density.value
sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_" % (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could run into the case that two matrices have the same sparse_prefix, but different non-zero structure. Will this cause issues? What if one of the matrices has one nonzero per row and the other has one dense row (while maintaining the same sparsity)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though in my test a schedule seems to have similar performance with different random sparse data, I think that may still be a potential problem. Unfortunately, I have not figured out any better solution.

Copy link
Contributor

@tkonolige tkonolige Jan 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could hash the indptr and indices arrays as these determine the structure. Alternatively you could hash the number of nonzeros per row.

It would be interesting to study if tuning performs the same independent of structure (but for the same sparsity).

prefix_init, M, N, K, BS_R, BS_C, density
)
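A minimal sketch of the reviewer's hashing suggestion, for illustration only: the helper below is hypothetical and not part of this PR. It would have to run where concrete numpy arrays are available (e.g. at buffer-registration time), since the tensors seen by _process_sparse_input are symbolic placeholders:

    import hashlib

    import numpy as np

    def _structure_hash(indices, indptr):
        # Hash the arrays that determine the non-zero structure, so two weights
        # with the same shape and density but different patterns get distinct
        # prefixes.
        h = hashlib.sha1()
        h.update(np.asarray(indices, dtype="int32").tobytes())
        h.update(np.asarray(indptr, dtype="int32").tobytes())
        return h.hexdigest()[:8]

    # The hash would then be folded into the prefix, e.g.:
    # sparse_prefix = "%s_%d_%d_%d_%d_%d_%.2f_%s_" % (
    #     prefix_init, M, N, K, BS_R, BS_C, density, structure_hash)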

    visited = set()

    def _traverse(t):
        # We cannot directly add tensors to the set, because the comparison of
        # two tensors with ndim=0 is ambiguous.
        assert t.handle is not None
        if t.handle.value in visited:
            return
        if isinstance(t.op, ComputeOp):
            # TODO(jcf94): Currently only supports tuning one sparse op
            if t.op.tag == "sparse_dense_sp_rhs_bsrmm":
                M, N = t.shape
                assert len(t.op.input_tensors) == 1
                block_tensor = t.op.input_tensors[0]
                _process_inputs(block_tensor.op.input_tensors, M, N, "sparse_dense_bsr")
            if t.op.tag == "sparse_conv2d_bsrmm":
                N, OH = t.shape[0], t.shape[1]
                assert len(t.op.input_tensors) == 1
                block_tensor = t.op.input_tensors[0]
                _process_inputs(block_tensor.op.input_tensors, N, OH, "sparse_dense_bsr")
        if sparse_prefix is not None:
            # Stop traversing once a sparse op has been matched
            return
        for x in t.op.input_tensors:
            _traverse(x)
        visited.add(t.handle.value)

    for arg in args:
        _traverse(arg)

    return sparse_prefix, sparse_data, sparse_indices, sparse_indptr
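For reference, a minimal sketch of how this detector would be exercised on a BSR matmul. The sizes are made up; topi.nn.sparse_dense builds the compute tagged "sparse_dense_sp_rhs_bsrmm" that the traversal matches:

    import tvm
    from tvm import te, topi

    M, N, K, BS_R, BS_C, NNZ_BLOCKS = 128, 512, 256, 16, 1, 64
    X = te.placeholder((M, K), dtype="float32", name="X")
    W_data = te.placeholder((NNZ_BLOCKS, BS_R, BS_C), dtype="float32", name="W_data")
    W_indices = te.placeholder((NNZ_BLOCKS,), dtype="int32", name="W_indices")
    W_indptr = te.placeholder((N // BS_R + 1,), dtype="int32", name="W_indptr")
    Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)

    prefix, data, indices, indptr = _process_sparse_input([Y])
    # prefix == "sparse_dense_bsr_128_512_256_16_1_0.01_"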

def _timed_eval_func(
    inp_serialized,
    build_res,

@@ -758,11 +842,31 @@ def _timed_eval_func(

     if error_no == 0:
         try:
-            args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
             random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
             assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
-            for arg in args:
-                random_fill(arg)
+
+            # Check sparse op: its data/indices/indptr arguments must come from
+            # the registered buffers, since random values would not form a valid
+            # BSR structure.
+            sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \
+                _process_sparse_input(build_res.args)
+            if sparse_prefix:
+                args = []
+                for arg in build_res.args:
+                    if arg == sparse_data:
+                        args.append(ndarray.array(get_special_buffer(sparse_prefix + "W_data"), ctx))
+                    elif arg == sparse_indices:
+                        args.append(ndarray.array(get_special_buffer(sparse_prefix + "W_indices"), ctx))
+                    elif arg == sparse_indptr:
+                        args.append(ndarray.array(get_special_buffer(sparse_prefix + "W_indptr"), ctx))
+                    else:
+                        empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                        random_fill(empty_array)
+                        args.append(empty_array)
+            else:
+                args = [
+                    ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args
+                ]
+                for arg in args:
+                    random_fill(arg)
             ctx.sync()
             costs = time_f(*args).results
         # pylint: disable=broad-except

@@ -943,18 +1047,36 @@ def _timed_rpc_run(

     if error_no == 0:
         try:
-            args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
             try:
                 random_fill = remote.get_function("tvm.contrib.random.random_fill")
             except AttributeError:
                 raise AttributeError(
                     "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
                 )
-            for arg in args:
-                random_fill(arg)
-            ctx.sync()
+
+            # Check sparse op: substitute the registered buffers, mirroring the
+            # logic in _timed_eval_func above.
+            sparse_prefix, sparse_data, sparse_indices, sparse_indptr = \
+                _process_sparse_input(build_res.args)
+            if sparse_prefix:
+                args = []
+                for arg in build_res.args:
+                    if arg == sparse_data:
+                        args.append(ndarray.array(get_special_buffer(sparse_prefix + "W_data"), ctx))
+                    elif arg == sparse_indices:
+                        args.append(ndarray.array(get_special_buffer(sparse_prefix + "W_indices"), ctx))
+                    elif arg == sparse_indptr:
+                        args.append(ndarray.array(get_special_buffer(sparse_prefix + "W_indptr"), ctx))
+                    else:
+                        empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                        random_fill(empty_array)
+                        args.append(empty_array)
+            else:
+                args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
+                for arg in args:
+                    random_fill(arg)
+            ctx.sync()
             costs = time_f(*args).results

             # clean up remote files
             remote.remove(build_res.filename)
             remote.remove(os.path.splitext(build_res.filename)[0] + ".so")

@@ -1132,3 +1254,44 @@ def rpc_runner_run(
print("")

return results


# This table stores specially registered buffers for measurement.
# It can be used for sparse workloads when we cannot use random tensors
# for measurement.
special_buffer_table = {}

def register_special_buffer(tensor_name, data):
    """Register a special buffer for measurement.

    This can be used for sparse workloads when we cannot use random tensors
    for measurement.
    """
    if tensor_name in special_buffer_table:
        return True

    if os.path.isfile(tensor_name):
        print("Load ", tensor_name)
        if tensor_name.startswith("sparse_dense_bsr"):
            # The name encodes "sparse_dense_bsr_M_N_K_BS_R_BS_C_density_...",
            # so fields 6 and 7 of the "_"-split are the block dimensions.
            if tensor_name.endswith("data"):
                data = np.fromfile(tensor_name, dtype="float32", sep=" ")
                name_split = tensor_name.split("_")
                BS_R = int(name_split[6])
                BS_C = int(name_split[7])
                data = data.reshape((data.shape[0] // BS_R // BS_C, BS_R, BS_C))
            else:
                data = np.fromfile(tensor_name, dtype="int32", sep=" ")
    elif data is None:
        return False

    special_buffer_table[tensor_name] = data

    if not os.path.isfile(tensor_name):
        data.tofile(tensor_name, sep=" ")

    return True
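A sketch of the intended registration flow under this naming scheme, using scipy to supply a real BSR structure. The sizes are illustrative; the W_data/W_indices/W_indptr suffixes match what the runners look up:

    import numpy as np
    import scipy.sparse as sp

    M, N, K, BS_R, BS_C = 128, 512, 256, 16, 1
    W = sp.random(N, K, density=0.15, format="csr", dtype="float32").tobsr(blocksize=(BS_R, BS_C))
    density = W.data.size / (N * K)  # same formula _process_sparse_input uses
    prefix = "sparse_dense_bsr_%d_%d_%d_%d_%d_%.2f_" % (M, N, K, BS_R, BS_C, density)
    register_special_buffer(prefix + "W_data", W.data.astype("float32"))
    register_special_buffer(prefix + "W_indices", W.indices.astype("int32"))
    register_special_buffer(prefix + "W_indptr", W.indptr.astype("int32"))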

def get_special_buffer(tensor_name):
    """Get a special buffer for measurement.

    This can be used for sparse workloads when we cannot use random tensors
    for measurement. The buffers are registered by `register_special_buffer`.
    """
    return special_buffer_table.get(tensor_name, None)
6 changes: 4 additions & 2 deletions python/tvm/topi/nn/sparse.py
@@ -197,7 +197,7 @@ def _compute_block(nb_j, j, i):


 def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr):
-    (m, _) = get_const_tuple(data.shape)
+    (m, k) = get_const_tuple(data.shape)
     (_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
     (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
     num_blocks = num_blocks_plus_1 - 1
@@ -218,7 +218,9 @@ def _compute_block(i, nb_j, j):
     idxm = tvm.tir.indexmod

     bsrmm_block = te.compute(
-        (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block"
+        (m, num_blocks, bs_r),
+        _compute_block,
+        tag="sparse_dense_sp_rhs_bsrmm_block",
+        attrs={"FLOP": 2 * m * num_blocks * bs_r * k},
     )
     return te.compute(
         (m, num_blocks * bs_r),
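The new FLOP attribute gives the auto-scheduler an explicit work estimate for the block compute, which it cannot derive on its own because the reduction extent depends on runtime indptr values. Since num_blocks * bs_r equals the number of output columns n, the estimate 2 * m * num_blocks * bs_r * k is the cost of the equivalent dense matmul; for example, m = 128, num_blocks = 32, bs_r = 16, and k = 256 give 2 * 128 * 512 * 256, roughly 33.6 MFLOP per call.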
16 changes: 16 additions & 0 deletions src/auto_scheduler/search_policy/utils.cc
@@ -465,6 +465,22 @@ const std::vector<int>& SplitFactorizationMemo::GetFactors(int n) {

/********** Utils interface API for ffi **********/

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsGetConsumers")
.set_body_typed([](const SearchTask& task, const State& state, int stage_id) {
const std::set<int>& consumers = GetConsumers(task, state, stage_id);
tvm::Map<IntImm, IntImm> ret;
for (const auto& i : consumers) {
ret.Set(Integer(i), Integer(i));
}
return ret;
});

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch")
.set_body_typed([](const SearchTask& task, const State& state, int stage_id,
int target_stage_id) {
return ElementwiseMatch(task, state, stage_id, target_stage_id);
});

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyUtilsIsTiled")
.set_body_typed([](const Stage& stage) { return IsTiled(stage); });

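To connect these FFI hooks to the PR title, here is a hypothetical sketch of how a custom sketch rule's condition callback could use the two new globals from Python. The rule logic below is illustrative and not part of this PR; the PASS/APPLY_AND_SKIP_REST return protocol is that of auto_scheduler.PreloadCustomSketchRule:

    import tvm._ffi
    from tvm import auto_scheduler

    _get_consumers = tvm._ffi.get_global_func("auto_scheduler.SearchPolicyUtilsGetConsumers")
    _is_ew_match = tvm._ffi.get_global_func("auto_scheduler.SearchPolicyUtilsIsElementwiseMatch")

    def meet_condition_func(search_policy, state, stage_id):
        task = search_policy.search_task
        consumers = _get_consumers(task, state, stage_id)
        if len(consumers) == 1:
            # The map's keys are the consumer stage ids (as IntImm).
            consumer_id = int(list(consumers.items())[0][0])
            if _is_ew_match(task, state, stage_id, consumer_id):
                return auto_scheduler.PreloadCustomSketchRule.APPLY_AND_SKIP_REST
        return auto_scheduler.PreloadCustomSketchRule.PASS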