From 76f8cc8b5e7b575edb61e14d992c5601aeed09d2 Mon Sep 17 00:00:00 2001
From: ANSHUMAN TRIPATHY
Date: Tue, 22 Dec 2020 11:43:01 +0530
Subject: [PATCH 1/4] [Frontend][Tensorflow] Sparse_Dense Op CSR scheduling issue resolved for both cuda & x86

---
 python/tvm/topi/cuda/sparse.py          | 45 ++++++++++++++++---
 python/tvm/topi/x86/sparse.py           | 18 ++++----
 .../frontend/tensorflow/test_forward.py |  3 +-
 3 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index c59e6887d47e..c461d8db7a9d 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -23,10 +23,10 @@
 from tvm import relay, te
 
 from .. import nn
-from ..utils import traverse_inline
+from ..utils import traverse_inline, get_const_tuple, prod, get_const_int
 
 
-def sparse_dense(data, weight_data, weight_indices, weight_indptr):
+def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
     """
     Computes sparse-dense matrix multiplication of `data` and
     `(weight_data, weight_indices, weight_indptr).T`
@@ -57,7 +57,7 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
         2-D with shape [M, N]
     """
     # pylint:disable=unused-argument
-    return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr)
+    return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs)
 
 
 def schedule_sparse_dense(outs):
@@ -65,11 +65,13 @@ def schedule_sparse_dense(outs):
     # pylint:disable=invalid-name
     s = te.create_schedule([x.op for x in outs])
 
-    # TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also
     def _callback(op):
-        if op.tag == "sparse_dense_bsrmm_v2":
+        if op.tag == "sparse_dense_bsrmm_v2" or op.tag == "sparse_dense_bsrmm_v1":
             y_bsrmm = op.input_tensors[0]
-            assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
+            assert (
+                y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
+                or y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v1"
+            )
             out = s.outputs[0].output(0)
 
             if op not in s.outputs:
@@ -91,6 +93,13 @@ def _callback(op):
             s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
             s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
             s[out].set_store_predicate(thread_x.var.equal(0))
+        elif op.tag == "sparse_dense_csrmm_v2" or op.tag == "sparse_dense_csrmm_v1":
+            out = op.output(0)
+            const_size = get_const_int(prod(out.shape))
+            fused = s[out].fuse(*s[out].op.axis)
+            bx, tx = s[out].split(fused, factor=const_size)
+            s[out].bind(tx, te.thread_axis("threadIdx.x"))
+            s[out].bind(bx, te.thread_axis("blockIdx.x"))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
@@ -279,7 +288,26 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
     return out
 
 
-def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr):
+def is_valid_for_sparse_dense_padded(data, weight_data):
+    """
+    Check whether the input is applicable for the sparse_dense_padded op.
+    If not, we should fall back to the default schedule.
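+    The padded variant is only selected when m // bs_m, the number of row
+    blocks computed below, is at least the warp size of the current target.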
+ """ + # pylint:disable=invalid-name + warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size) + m = get_const_tuple(data.checked_type.shape)[1] + if len(weight_data.shape) == 1: + bs_m = 1 + else: + bs_m = weight_data.shape[1] + + mb = m // bs_m + if mb >= warp_size: + return True + return False + + +def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False): """ Computes sparse-dense matrix multiplication of `data` and `(weight_data, weight_indices, weight_indptr).T` @@ -311,6 +339,8 @@ def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr): output : tvm.te.Tensor 2-D with shape [M, N] """ + # TODO(ANSHUMAN87): Handle for sparse_lhs case too + assert not sparse_lhs return sparse_dense_tir(data, weight_data, weight_indices, weight_indptr) @@ -368,6 +398,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type): isinstance(inputs[1], relay.Constant) and isinstance(inputs[2], relay.Constant) and isinstance(inputs[3], relay.Constant) + and is_valid_for_sparse_dense_padded(inputs[0], inputs[1].data.asnumpy()) ): if len(inputs[1].data.asnumpy().shape) == 1: sparse_matrix = sp.csr_matrix( diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index b6291083c8c1..8b4972e67c4a 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -28,15 +28,17 @@ def schedule_sparse_dense(outs): def _callback(op): simd_width = get_fp32_len() - if op.tag == "sparse_dense_csrmm" and op != outs[0].op: - (_, v_i) = s[op].op.axis - s[op].vectorize(v_i) - (y_o, y_i) = s[outs[0].op].split(s[outs[0].op].op.axis[1], 2 * simd_width) - s[op].compute_at(s[outs[0]], y_o) - s[outs[0].op].vectorize(y_i) - if op.tag == "sparse_dense_bsrmm": + if op.tag == "sparse_dense_csrmm_v2" or op.tag == "sparse_dense_csrmm_v1": + (y_o, y_i) = s[op].split(s[op].op.axis[1], 2) + fused = s[op].fuse(s[op].op.axis[0], y_o) + s[op].parallel(fused) + s[op].vectorize(y_i) + elif op.tag == "sparse_dense_bsrmm_v2" or op.tag == "sparse_dense_bsrmm_v1": y_bsrmm = op.input_tensors[0] - assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block" + assert ( + y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2" + or y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v1" + ) y_reshape = op (m, num_blocks, b_r) = s[y_bsrmm].op.axis bs_r = get_const_int(b_r.dom.extent) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 4bd3b919bc74..676203acbdcb 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1776,8 +1776,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) - # TODO(ANSHUMAN87): There is an issue in cuda scheduling for csr, work in progress - compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True) + compare_tf_with_tvm([B_np], [B.name], result.name) def test_forward_sparse_dense_matmul(): From af0a5490370f5577324c61a6b8ab945eb866f5d4 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 24 Dec 2020 11:22:41 +0530 Subject: [PATCH 2/4] [1] Review comments handled --- python/tvm/topi/cuda/sparse.py | 10 +++++----- python/tvm/topi/nn/sparse.py | 36 +++++++++++++++++----------------- python/tvm/topi/x86/sparse.py | 8 ++++---- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index c461d8db7a9d..f2cecacbc618 100644 --- 
a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -66,11 +66,11 @@ def schedule_sparse_dense(outs): s = te.create_schedule([x.op for x in outs]) def _callback(op): - if op.tag == "sparse_dense_bsrmm_v2" or op.tag == "sparse_dense_bsrmm_v1": + if op.tag == "sparse_dense_sp_rhs_bsrmm" or op.tag == "sparse_dense_sp_lhs_bsrmm": y_bsrmm = op.input_tensors[0] assert ( - y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2" - or y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v1" + y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block" + or y_bsrmm.op.tag == "sparse_dense_sp_lhs_bsrmm_block" ) out = s.outputs[0].output(0) @@ -93,7 +93,7 @@ def _callback(op): s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx) s[y_bsrmm].set_store_predicate(thread_x.var.equal(0)) s[out].set_store_predicate(thread_x.var.equal(0)) - elif op.tag == "sparse_dense_csrmm_v2" or op.tag == "sparse_dense_csrmm_v1": + elif op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_rhs_csrmm": out = op.output(0) const_size = get_const_int(prod(out.shape)) fused = s[out].fuse(*s[out].op.axis) @@ -340,7 +340,7 @@ def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr, sparse 2-D with shape [M, N] """ # TODO(ANSHUMAN87): Handle for sparse_lhs case too - assert not sparse_lhs + assert not sparse_lhs, "Currently only sparse weight is supported." return sparse_dense_tir(data, weight_data, weight_indices, weight_indptr) diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index 94d6d9a16330..cdccc80bb5f8 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -23,7 +23,7 @@ from ..utils import get_const_tuple -def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr): +def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr): """ Computes sparse-dense matrix multiplication of `data` and `(weight_data, weight_indices, weight_indptr).T` @@ -52,13 +52,13 @@ def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr): """ assert len(weight_data.shape) in (1, 3) if len(weight_data.shape) == 1: - func = _sparse_dense_csrmm_v2 + func = _sparse_dense_sp_rhs_csrmm if len(weight_data.shape) == 3: - func = _sparse_dense_bsrmm_v2 + func = _sparse_dense_sp_rhs_bsrmm return func(data, weight_data, weight_indices, weight_indptr) -def sparse_dense_v1(data_data, data_indices, data_indptr, weight): +def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight): """ Computes sparse-dense matrix multiplication of `(data_data, data_indices, data_indptr)` and `weight.T` @@ -87,9 +87,9 @@ def sparse_dense_v1(data_data, data_indices, data_indptr, weight): """ assert len(data_data.shape) in (1, 3) if len(data_data.shape) == 1: - func = _sparse_dense_csrmm_v1 + func = _sparse_dense_sp_lhs_csrmm if len(data_data.shape) == 3: - func = _sparse_dense_bsrmm_v1 + func = _sparse_dense_sp_lhs_bsrmm return func(data_data, data_indices, data_indptr, weight) @@ -128,12 +128,12 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_ 2-D with shape [M, N] """ if sparse_lhs: - return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data) + return sparse_dense_sp_lhs(sparse_data, sparse_indices, sparse_indptr, dense_data) else: - return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr) + return sparse_dense_sp_rhs(dense_data, sparse_data, sparse_indices, sparse_indptr) -def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight): +def _sparse_dense_sp_lhs_csrmm(data_data, 
data_indices, data_indptr, weight):
     oshape = (get_const_tuple(data_indptr.shape)[0] - 1, get_const_tuple(weight.shape)[0])
 
     def f(row, i):
@@ -146,10 +146,10 @@ def f(row, i):
             weight_val = weight[i, data_indices[elem]]
         return te.sum(a_val * weight_val, axis=elem_idx)
 
-    return te.compute(oshape, f, tag="sparse_dense_csrmm_v1")
+    return te.compute(oshape, f, tag="sparse_dense_sp_lhs_csrmm")
 
 
-def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr):
+def _sparse_dense_sp_rhs_csrmm(data, weight_data, weight_indices, weight_indptr):
     oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)
 
     def f(i, row):
@@ -162,10 +162,10 @@ def f(i, row):
             weight_val = data[i, weight_indices[elem]]
         return te.sum(a_val * weight_val, axis=elem_idx)
 
-    return te.compute(oshape, f, tag="sparse_dense_csrmm_v2")
+    return te.compute(oshape, f, tag="sparse_dense_sp_rhs_csrmm")
 
 
-def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
+def _sparse_dense_sp_lhs_bsrmm(data_data, data_indices, data_indptr, weight):
     (m, _) = get_const_tuple(weight.shape)
     (_, bs_r, bs_c) = get_const_tuple(data_data.shape)
     (num_blocks_plus_1,) = get_const_tuple(data_indptr.shape)
@@ -187,16 +187,16 @@ def _compute_block(nb_j, j, i):
     idxm = tvm.tir.indexmod
 
     bsrmm_block = te.compute(
-        (num_blocks, bs_r, m), _compute_block, tag="sparse_dense_bsrmm_block_v1"
+        (num_blocks, bs_r, m), _compute_block, tag="sparse_dense_sp_lhs_bsrmm_block"
     )
     return te.compute(
         (num_blocks * bs_r, m),
         lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n],
-        tag="sparse_dense_bsrmm_v1",
+        tag="sparse_dense_sp_lhs_bsrmm",
     )
 
 
-def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr):
+def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr):
     (m, _) = get_const_tuple(data.shape)
     (_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
     (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
@@ -218,12 +218,12 @@ def _compute_block(i, nb_j, j):
     idxm = tvm.tir.indexmod
 
     bsrmm_block = te.compute(
-        (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2"
+        (m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block"
     )
     return te.compute(
         (m, num_blocks * bs_r),
         lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
-        tag="sparse_dense_bsrmm_v2",
+        tag="sparse_dense_sp_rhs_bsrmm",
     )
 
 
diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py
index 8b4972e67c4a..c6300f6701e0 100644
--- a/python/tvm/topi/x86/sparse.py
+++ b/python/tvm/topi/x86/sparse.py
@@ -28,16 +28,16 @@ def schedule_sparse_dense(outs):
 
     def _callback(op):
         simd_width = get_fp32_len()
-        if op.tag == "sparse_dense_csrmm_v2" or op.tag == "sparse_dense_csrmm_v1":
+        if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_rhs_csrmm":
             (y_o, y_i) = s[op].split(s[op].op.axis[1], 2)
             fused = s[op].fuse(s[op].op.axis[0], y_o)
             s[op].parallel(fused)
             s[op].vectorize(y_i)
-        elif op.tag == "sparse_dense_bsrmm_v2" or op.tag == "sparse_dense_bsrmm_v1":
+        elif op.tag == "sparse_dense_sp_rhs_bsrmm" or op.tag == "sparse_dense_sp_lhs_bsrmm":
             y_bsrmm = op.input_tensors[0]
             assert (
-                y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
-                or y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v1"
+                y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
+                or y_bsrmm.op.tag == "sparse_dense_sp_lhs_bsrmm_block"
             )
             y_reshape = op
             (m, num_blocks, b_r) = s[y_bsrmm].op.axis

From c078d0b2bf12fa969472eaeec37aa36b064f6ba1 Mon Sep 17 00:00:00 2001
From: ANSHUMAN TRIPATHY
Date: Thu, 14 Jan 
2021 21:34:32 +0530 Subject: [PATCH 3/4] [2] Review comments handled --- tests/python/topi/python/test_topi_sparse.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index e47bfddbf7fc..dea318c5b1df 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -521,6 +521,33 @@ def test_sparse_dense_padded_alter_op(): assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded" +@tvm.testing.requires_cuda +def test_sparse_dense_padded_alter_op_var_inp(): + with tvm.target.Target("cuda"): + M = 128 + N = 16 + K = 128 + X_np = np.random.randn(M, K).astype("float32") + W_sp_np = random_bsr_matrix(N, K, 2, 2, density=0.01, dtype="float32") + x = relay.var("x", relay.TensorType(X_np.shape, "float32")) + mult = relay.op.nn.sparse_dense( + x, + ( + relay.Constant(tvm.nd.array(W_sp_np.data)), + relay.Constant(tvm.nd.array(W_sp_np.indices)), + relay.Constant(tvm.nd.array(W_sp_np.indptr)), + ), + ) + f = relay.Function([x], mult) + f_ = relay.transform.InferType()(tvm.IRModule.from_expr(f)) + f_ = relay.transform.AlterOpLayout()(f_) + assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded" + + # build with cuda and AlterOpLayout to ensure that sparse_dense_padded is in action + with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"): + x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda")) + + if __name__ == "__main__": test_csrmv() test_csrmm() @@ -530,5 +557,6 @@ def test_sparse_dense_padded_alter_op(): test_sparse_transpose_csr() test_sparse_dense_padded_cuda() test_sparse_dense_padded_alter_op() + test_sparse_dense_padded_alter_op_var_inp() test_sparse_dense_csr_reverse() test_sparse_dense_bsr_reverse() From 3829c5bedac94de5cfe3c44de9d02421566fd53e Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 14 Jan 2021 22:06:55 +0530 Subject: [PATCH 4/4] [3] Review comments handled --- tests/python/topi/python/test_topi_sparse.py | 23 -------------------- 1 file changed, 23 deletions(-) diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index dea318c5b1df..d5bd7aa1a21e 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -501,28 +501,6 @@ def test_sparse_dense_padded_cuda(): @tvm.testing.requires_cuda def test_sparse_dense_padded_alter_op(): - with tvm.target.Target("cuda"): - M = 128 - N = 16 - K = 128 - X_np = np.random.randn(M, K).astype("float32") - W_sp_np = random_bsr_matrix(N, K, 2, 2, density=0.01, dtype="float32") - mult = relay.op.nn.sparse_dense( - relay.Constant(tvm.nd.array(X_np)), - ( - relay.Constant(tvm.nd.array(W_sp_np.data)), - relay.Constant(tvm.nd.array(W_sp_np.indices)), - relay.Constant(tvm.nd.array(W_sp_np.indptr)), - ), - ) - f = relay.Function([], mult) - f = relay.transform.InferType()(tvm.IRModule.from_expr(f)) - f_ = relay.transform.AlterOpLayout()(f) - assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded" - - -@tvm.testing.requires_cuda -def test_sparse_dense_padded_alter_op_var_inp(): with tvm.target.Target("cuda"): M = 128 N = 16 @@ -557,6 +535,5 @@ def test_sparse_dense_padded_alter_op_var_inp(): test_sparse_transpose_csr() test_sparse_dense_padded_cuda() test_sparse_dense_padded_alter_op() - test_sparse_dense_padded_alter_op_var_inp() test_sparse_dense_csr_reverse() test_sparse_dense_bsr_reverse()
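---

For context on how the reworked entry point is meant to be driven, below is a
minimal end-to-end sketch (not part of the patches above). The shapes, the
density, and the "llvm" target are illustrative assumptions; the pattern
mirrors the CSR tests in tests/python/topi/python/test_topi_sparse.py.

    import numpy as np
    import scipy.sparse as sp
    import tvm
    from tvm import te, topi

    M, N, K = 4, 8, 16
    X_np = np.random.rand(M, K).astype("float32")
    # Random CSR weight matrix; sparse_dense computes X @ W.T.
    W_sp = sp.random(N, K, density=0.5, format="csr", dtype="float32")

    X = te.placeholder((M, K), name="X", dtype="float32")
    W_data = te.placeholder(W_sp.data.shape, name="W_data", dtype="float32")
    W_indices = te.placeholder(W_sp.indices.shape, name="W_indices", dtype="int32")
    W_indptr = te.placeholder(W_sp.indptr.shape, name="W_indptr", dtype="int32")

    # sparse_lhs=False selects the sp_rhs (sparse weight) path renamed above.
    Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr, sparse_lhs=False)
    s = te.create_schedule(Y.op)
    func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y], "llvm")

    dev = tvm.cpu()
    y = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
    func(
        tvm.nd.array(X_np, dev),
        tvm.nd.array(W_sp.data.astype("float32"), dev),
        tvm.nd.array(W_sp.indices.astype("int32"), dev),
        tvm.nd.array(W_sp.indptr.astype("int32"), dev),
        y,
    )
    np.testing.assert_allclose(y.asnumpy(), X_np @ W_sp.toarray().T, rtol=1e-4)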