[Frontend][Tensorflow] Sparse_Dense Op CSR scheduling issue resolved for Cuda & X86 (apache#7148)

* [Frontend][Tensorflow] Sparse_Dense Op CSR scheduling issue resolved for both cuda & x86

* [1] Review comments handled

* [2] Review comments handled

* [3] Review comments handled
ANSHUMAN TRIPATHY authored and trevor-m committed Jan 21, 2021
1 parent e12d490 commit a2af9e1
Showing 5 changed files with 76 additions and 39 deletions.
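For reference, the op touched by this commit computes Y = X * W^T with the weight W stored in CSR (or BSR) form; the patch changes only how that computation is scheduled on CUDA and x86, not its semantics. A minimal NumPy/SciPy sketch of those semantics (shapes and density are arbitrary assumptions, not taken from the patch):

```python
import numpy as np
import scipy.sparse as sp

# Dense data X: [M, K]; sparse weight W: [N, K] stored in CSR form.
M, N, K = 4, 8, 16
X = np.random.rand(M, K).astype("float32")
W = sp.random(N, K, density=0.25, format="csr", dtype="float32")

# sparse_dense(X, W.data, W.indices, W.indptr) computes X @ W.T -> [M, N].
Y = X @ W.toarray().T
assert Y.shape == (M, N)
```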
45 changes: 38 additions & 7 deletions python/tvm/topi/cuda/sparse.py
@@ -23,10 +23,10 @@
from tvm import relay, te

from .. import nn
from ..utils import traverse_inline
from ..utils import traverse_inline, get_const_tuple, prod, get_const_int


def sparse_dense(data, weight_data, weight_indices, weight_indptr):
def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
@@ -57,19 +57,21 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
2-D with shape [M, N]
"""
# pylint:disable=unused-argument
return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr)
return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs)


def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])

# TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also
def _callback(op):
if op.tag == "sparse_dense_bsrmm_v2":
if op.tag == "sparse_dense_sp_rhs_bsrmm" or op.tag == "sparse_dense_sp_lhs_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
assert (
y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
or y_bsrmm.op.tag == "sparse_dense_sp_lhs_bsrmm_block"
)
out = s.outputs[0].output(0)

if op not in s.outputs:
@@ -91,6 +93,13 @@ def _callback(op):
s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
s[out].set_store_predicate(thread_x.var.equal(0))
elif op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_rhs_csrmm":
out = op.output(0)
const_size = get_const_int(prod(out.shape))
fused = s[out].fuse(*s[out].op.axis)
bx, tx = s[out].split(fused, factor=const_size)
s[out].bind(tx, te.thread_axis("threadIdx.x"))
s[out].bind(bx, te.thread_axis("blockIdx.x"))

traverse_inline(s, outs[0].op, _callback)
return s
@@ -279,7 +288,26 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
return out


def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr):
def is_valid_for_sparse_dense_padded(data, weight_data):
"""
Check whether input is applicable for sparse_dense_padded op.
If not, we should fall back to default scheduling.
"""
# pylint:disable=invalid-name
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
m = get_const_tuple(data.checked_type.shape)[1]
if len(weight_data.shape) == 1:
bs_m = 1
else:
bs_m = weight_data.shape[1]

mb = m // bs_m
if mb >= warp_size:
return True
return False


def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
@@ -311,6 +339,8 @@ def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr):
output : tvm.te.Tensor
2-D with shape [M, N]
"""
# TODO(ANSHUMAN87): Handle for sparse_lhs case too
assert not sparse_lhs, "Currently only sparse weight is supported."
return sparse_dense_tir(data, weight_data, weight_indices, weight_indptr)


@@ -368,6 +398,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
isinstance(inputs[1], relay.Constant)
and isinstance(inputs[2], relay.Constant)
and isinstance(inputs[3], relay.Constant)
and is_valid_for_sparse_dense_padded(inputs[0], inputs[1].data.asnumpy())
):
if len(inputs[1].data.asnumpy().shape) == 1:
sparse_matrix = sp.csr_matrix(
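As a quick way to exercise the new CSR branch of the CUDA schedule above, the following hedged sketch builds a small CSR sparse_dense kernel through the TOPI entry points (shapes, density, and the SciPy-generated weight are assumptions; a CUDA-enabled TVM build is required):

```python
import scipy.sparse as sp
import tvm
from tvm import te, topi

M, N, K = 8, 16, 32
W_sp = sp.random(N, K, density=0.2, format="csr", dtype="float32")

data = te.placeholder((M, K), name="data", dtype="float32")
w_data = te.placeholder(W_sp.data.shape, name="w_data", dtype="float32")
w_indices = te.placeholder(W_sp.indices.shape, name="w_indices", dtype="int32")
w_indptr = te.placeholder(W_sp.indptr.shape, name="w_indptr", dtype="int32")

with tvm.target.Target("cuda"):
    # A 1-D weight_data selects the CSR kernel, whose output is now fused and
    # bound to blockIdx.x / threadIdx.x by the schedule branch added above.
    out = topi.cuda.sparse_dense(data, w_data, w_indices, w_indptr)
    s = topi.cuda.schedule_sparse_dense([out])
    func = tvm.build(s, [data, w_data, w_indices, w_indptr, out], "cuda")
```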
36 changes: 18 additions & 18 deletions python/tvm/topi/nn/sparse.py
@@ -23,7 +23,7 @@
from ..utils import get_const_tuple


def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
@@ -52,13 +52,13 @@ def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm_v2
func = _sparse_dense_sp_rhs_csrmm
if len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm_v2
func = _sparse_dense_sp_rhs_bsrmm
return func(data, weight_data, weight_indices, weight_indptr)


def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight):
"""
Computes sparse-dense matrix multiplication of
`(data_data, data_indices, data_indptr)` and `weight.T`
@@ -87,9 +87,9 @@ def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
"""
assert len(data_data.shape) in (1, 3)
if len(data_data.shape) == 1:
func = _sparse_dense_csrmm_v1
func = _sparse_dense_sp_lhs_csrmm
if len(data_data.shape) == 3:
func = _sparse_dense_bsrmm_v1
func = _sparse_dense_sp_lhs_bsrmm
return func(data_data, data_indices, data_indptr, weight)


@@ -128,12 +128,12 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_
2-D with shape [M, N]
"""
if sparse_lhs:
return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data)
return sparse_dense_sp_lhs(sparse_data, sparse_indices, sparse_indptr, dense_data)
else:
return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr)
return sparse_dense_sp_rhs(dense_data, sparse_data, sparse_indices, sparse_indptr)


def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight):
def _sparse_dense_sp_lhs_csrmm(data_data, data_indices, data_indptr, weight):
oshape = (get_const_tuple(data_indptr.shape)[0] - 1, get_const_tuple(weight.shape)[0])

def f(row, i):
@@ -146,10 +146,10 @@ def f(row, i):
weight_val = weight[i, data_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm_v1")
return te.compute(oshape, f, tag="sparse_dense_sp_lhs_csrmm")


def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr):
def _sparse_dense_sp_rhs_csrmm(data, weight_data, weight_indices, weight_indptr):
oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)

def f(i, row):
@@ -162,10 +162,10 @@ def f(i, row):
weight_val = data[i, weight_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm_v2")
return te.compute(oshape, f, tag="sparse_dense_sp_rhs_csrmm")


def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
def _sparse_dense_sp_lhs_bsrmm(data_data, data_indices, data_indptr, weight):
(m, _) = get_const_tuple(weight.shape)
(_, bs_r, bs_c) = get_const_tuple(data_data.shape)
(num_blocks_plus_1,) = get_const_tuple(data_indptr.shape)
@@ -187,16 +187,16 @@ def _compute_block(nb_j, j, i):
idxm = tvm.tir.indexmod

bsrmm_block = te.compute(
(num_blocks, bs_r, m), _compute_block, tag="sparse_dense_bsrmm_block_v1"
(num_blocks, bs_r, m), _compute_block, tag="sparse_dense_sp_lhs_bsrmm_block"
)
return te.compute(
(num_blocks * bs_r, m),
lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n],
tag="sparse_dense_bsrmm_v1",
tag="sparse_dense_sp_lhs_bsrmm",
)


def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr):
def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
@@ -218,12 +218,12 @@ def _compute_block(i, nb_j, j):
idxm = tvm.tir.indexmod

bsrmm_block = te.compute(
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2"
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block"
)
return te.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
tag="sparse_dense_bsrmm_v2",
tag="sparse_dense_sp_rhs_bsrmm",
)


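The renamed kernels above are selected through the sparse_lhs flag of nn.sparse_dense. A minimal sketch of the dispatch for the 1-D (CSR) paths, with placeholder shapes chosen as assumptions:

```python
from tvm import te, topi

M, N, K, nnz = 4, 8, 16, 20

# sparse_lhs=False (default): the CSR triplet describes the weight [N, K].
data = te.placeholder((M, K), name="data", dtype="float32")
w_data = te.placeholder((nnz,), name="w_data", dtype="float32")
w_indices = te.placeholder((nnz,), name="w_indices", dtype="int32")
w_indptr = te.placeholder((N + 1,), name="w_indptr", dtype="int32")
out_rhs = topi.nn.sparse_dense(data, w_data, w_indices, w_indptr)
assert out_rhs.op.tag == "sparse_dense_sp_rhs_csrmm"

# sparse_lhs=True: the CSR triplet describes the data [M, K] and the dense
# operand is the weight [N, K]; the output is [M, N] in both cases.
weight = te.placeholder((N, K), name="weight", dtype="float32")
d_data = te.placeholder((nnz,), name="d_data", dtype="float32")
d_indices = te.placeholder((nnz,), name="d_indices", dtype="int32")
d_indptr = te.placeholder((M + 1,), name="d_indptr", dtype="int32")
out_lhs = topi.nn.sparse_dense(weight, d_data, d_indices, d_indptr, sparse_lhs=True)
assert out_lhs.op.tag == "sparse_dense_sp_lhs_csrmm"
```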
18 changes: 10 additions & 8 deletions python/tvm/topi/x86/sparse.py
@@ -28,15 +28,17 @@ def schedule_sparse_dense(outs):

def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
(_, v_i) = s[op].op.axis
s[op].vectorize(v_i)
(y_o, y_i) = s[outs[0].op].split(s[outs[0].op].op.axis[1], 2 * simd_width)
s[op].compute_at(s[outs[0]], y_o)
s[outs[0].op].vectorize(y_i)
if op.tag == "sparse_dense_bsrmm":
if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_lhs_csrmm":
(y_o, y_i) = s[op].split(s[op].op.axis[1], 2)
fused = s[op].fuse(s[op].op.axis[0], y_o)
s[op].parallel(fused)
s[op].vectorize(y_i)
elif op.tag == "sparse_dense_sp_rhs_bsrmm" or op.tag == "sparse_dense_sp_rhs_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
assert (
y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
or y_bsrmm.op.tag == "sparse_dense_sp_lhs_bsrmm_block"
)
y_reshape = op
(m, num_blocks, b_r) = s[y_bsrmm].op.axis
bs_r = get_const_int(b_r.dom.extent)
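An end-to-end check of the x86 CSR path against SciPy; a hedged sketch with arbitrary shapes and density (indices and indptr are cast to int32 because SciPy may emit int64 on some platforms):

```python
import numpy as np
import scipy.sparse as sp
import tvm
from tvm import te, topi

M, N, K = 8, 8, 16
X_np = np.random.rand(M, K).astype("float32")
W_sp = sp.random(N, K, density=0.3, format="csr", dtype="float32")
w_indices_np = W_sp.indices.astype("int32")
w_indptr_np = W_sp.indptr.astype("int32")

data = te.placeholder((M, K), name="data", dtype="float32")
w_data = te.placeholder(W_sp.data.shape, name="w_data", dtype="float32")
w_indices = te.placeholder(w_indices_np.shape, name="w_indices", dtype="int32")
w_indptr = te.placeholder(w_indptr_np.shape, name="w_indptr", dtype="int32")

out = topi.nn.sparse_dense(data, w_data, w_indices, w_indptr)
s = topi.x86.schedule_sparse_dense([out])
f = tvm.build(s, [data, w_data, w_indices, w_indptr, out], "llvm")

# Run the compiled kernel and compare with a dense SciPy reference.
dev = tvm.cpu()
y_nd = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
f(
    tvm.nd.array(X_np, dev),
    tvm.nd.array(W_sp.data.astype("float32"), dev),
    tvm.nd.array(w_indices_np, dev),
    tvm.nd.array(w_indptr_np, dev),
    y_nd,
)
np.testing.assert_allclose(y_nd.asnumpy(), X_np @ W_sp.toarray().T, rtol=1e-5)
```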
3 changes: 1 addition & 2 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1794,8 +1794,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal

B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)

# TODO(ANSHUMAN87): There is an issue in cuda scheduling for csr, work in progress
compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)
compare_tf_with_tvm([B_np], [B.name], result.name)


def test_forward_sparse_dense_matmul():
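For context, the TensorFlow pattern this test lowers is tf.sparse.sparse_dense_matmul of a SparseTensor with a dense matrix; with the CSR schedule now in place on CUDA, compare_tf_with_tvm no longer needs to skip GPU targets. A tiny standalone sketch of that pattern (values chosen arbitrarily, TF 2.x assumed):

```python
import numpy as np
import tensorflow as tf

# A 3x4 sparse matrix A multiplied by a 4x3 dense matrix B gives a 3x3 result.
A = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=[3, 4])
B = tf.constant(np.random.uniform(high=5.0, size=(4, 3)).astype("float32"))
result = tf.sparse.sparse_dense_matmul(A, B)
```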
13 changes: 9 additions & 4 deletions tests/python/topi/python/test_topi_sparse.py
@@ -507,19 +507,24 @@ def test_sparse_dense_padded_alter_op():
K = 128
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = random_bsr_matrix(N, K, 2, 2, density=0.01, dtype="float32")
x = relay.var("x", relay.TensorType(X_np.shape, "float32"))
mult = relay.op.nn.sparse_dense(
relay.Constant(tvm.nd.array(X_np)),
x,
(
relay.Constant(tvm.nd.array(W_sp_np.data)),
relay.Constant(tvm.nd.array(W_sp_np.indices)),
relay.Constant(tvm.nd.array(W_sp_np.indptr)),
),
)
f = relay.Function([], mult)
f = relay.transform.InferType()(tvm.IRModule.from_expr(f))
f_ = relay.transform.AlterOpLayout()(f)
f = relay.Function([x], mult)
f_ = relay.transform.InferType()(tvm.IRModule.from_expr(f))
f_ = relay.transform.AlterOpLayout()(f_)
assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded"

# build with cuda and AlterOpLayout to ensure that sparse_dense_padded is in action
with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"):
x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda"))


if __name__ == "__main__":
test_csrmv()