Skip to content

Commit

Permalink
[Relay][Frontend] SparseTensorDenseMatMul support for Tensorflow (apa…
Browse files Browse the repository at this point in the history
…che#6685)

* [Relay][Frontend] SparseTensorDenseMatMul support for Tensorflow

* Lint error resolved

* [1] Review comments handled

* [2] Review comments handled
  • Loading branch information
ANSHUMAN TRIPATHY authored and Trevor Morris committed Dec 4, 2020
1 parent 85c2c0e commit 8b1ccc9
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 6 deletions.
46 changes: 46 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,51 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_tensor_dense_matmul():
# Sparse utility from scipy
from scipy.sparse import csr_matrix

def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"

indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()

data = inputs[3]

rows = [x[0] for x in indices_tensor]
cols = [x[1] for x in indices_tensor]

# Create scipy sparse Tensor(CSR)
weight_sp = csr_matrix(
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
)
weight_sp = csr_matrix(weight_sp.transpose())

weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)

ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])

# If both are true means First input was dense and second was sparse
# TODO(ANSHUMAN87): Support other adjoint option too
if attr.get("adjoint_a") and attr.get("adjoint_b"):
ret = _op.transpose(ret)
else:
raise tvm.error.OpAttributeUnImplemented(
"Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
" is supported, but adjoint_a={} and adjoint_b={} was supplied.".format(
attr.get("adjoint_a"), attr.get("adjoint_b")
)
)

return ret

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
Expand Down Expand Up @@ -2411,6 +2456,7 @@ def _impl(inputs, attr, params, mod):
"SpaceToBatchND": _space_to_batch_nd(),
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2046,7 +2046,7 @@ def sparse_transpose(x):
Parameters
----------
x : namedtuple.
x : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the fast matrix transpose.
Returns
Expand All @@ -2055,7 +2055,9 @@ def sparse_transpose(x):
Tuple of output sparse tensor (same shape and format as input),
i.e. if CSR then output is in ([data, indices, indptr]) form
"""
return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
if hasattr(x, "indices"):
return expr.TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3)


def contrib_conv2d_winograd_without_weight_transform(
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
assert (
mb >= mi
), "Number of block rows in dense matrix must be larger than warp size: {} vs {}.".format(
warp_size, m
warp_size, mb
)
mo = ceil_div(mb, mi)
ni = 1 # TODO(tkonolige): how do I compute the number of warps per block?
Expand Down Expand Up @@ -367,9 +367,14 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
and isinstance(inputs[2], relay.Constant)
and isinstance(inputs[3], relay.Constant)
):
sparse_matrix = sp.bsr_matrix(
(inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
)
if len(inputs[1].data.asnumpy().shape) == 1:
sparse_matrix = sp.csr_matrix(
(inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
).tobsr()
else:
sparse_matrix = sp.bsr_matrix(
(inputs[1].data.asnumpy(), inputs[2].data.asnumpy(), inputs[3].data.asnumpy())
)
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
sparse_matrix = pad_sparse_matrix(sparse_matrix, warp_size)
return relay.nn._make.sparse_dense_padded(
Expand Down
56 changes: 56 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,62 @@ def test_forward_batch_matmul():
_test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)


#######################################################################
# SparseTensorDenseMatMul
# ----------------------------------


def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
""" One iteration of sparse_dense_matmul """

# TODO(ANSHUMAN87): Support adjoint options too
for adjoint_a in [False]:
for adjoint_b in [False]:
with tf.Graph().as_default():
A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")

if flip:
result = tf.sparse.sparse_dense_matmul(
B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b
)
else:
result = tf.sparse.sparse_dense_matmul(
A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b
)

B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)

# TODO(ANSHUMAN87): There is an issue in cuda scheduling for csr, work in progress
compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)


def test_forward_sparse_dense_matmul():
""" sparse_dense_matmul op test"""
###################################################################
#
# In order to create a SparseTensor, it requires 3 input as below:
# SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
#
# Above Sparse can be represented in Dense as below :
# [[1, 0, 0, 0]
# [0, 0, 2, 0]
# [0, 0, 0, 0]]
#
# ------------------------------------------------------------------

# TODO(ANSHUMAN87): False case for flip need to be supported
# _test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [4, 3], "float32")
_test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 5], [4, 3], "float32", True)
_test_sparse_dense_matmul([[0, 0], [1, 2]], [4.0, 8.0], [3, 3], [3, 3], "float32", True)
_test_sparse_dense_matmul(
[[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [5, 5], [5, 5], "float32", True
)
_test_sparse_dense_matmul(
[[0, 0], [1, 3], [4, 3]], [3.0, 6.0, 9.0], [9, 5], [7, 9], "float32", True
)


#######################################################################
# StridedSlice
# ------------
Expand Down

0 comments on commit 8b1ccc9

Please sign in to comment.