From 6db2d3a404a3f4bcabffd440e276a1a6043c1778 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Mon, 2 Nov 2020 13:23:01 +0530 Subject: [PATCH] [TENSORFLOW]Sparse2Dense support (#5767) * [TENSORFLOW]Sparse2Dense support * Formatting issues fixed --- python/tvm/relay/frontend/tensorflow.py | 13 ++++ .../frontend/tensorflow/test_forward.py | 77 +++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6a23c8da9739..2c7adf03bad8 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1281,6 +1281,18 @@ def _impl(inputs, attr, params, mod): return _impl +def _sparse_to_dense(): + def _impl(inputs, attr, params, mod): + sparse_indices = inputs[0] + sparse_values = inputs[2] + default_value = inputs[3] + output_shape = attr["_output_shapes"][0] + + return _op.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) + + return _impl + + def _bias_add(): def _impl(inputs, attr, params, mod): # Must expand for proper broadcasting in NCHW. @@ -2394,6 +2406,7 @@ def _impl(inputs, attr, params, mod): "Softsign": _softsign(), "SpaceToBatchND": _space_to_batch_nd(), "SpaceToDepth": _space_to_depth(), + "SparseToDense": _sparse_to_dense(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 94c2c440e4d1..6697cfd0d36f 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3968,6 +3968,83 @@ def test_forward_dilation(): _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID") +####################################################################### +# Sparse To Dense +# --------------- +def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): + with tf.Graph().as_default(): + indices = tf.placeholder( + shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices" + ) + values = tf.placeholder( + shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values" + ) + oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype)) + + if default_value == None: + output = tf.sparse_to_dense(indices, oshape, values) + compare_tf_with_tvm( + [sparse_indices, sparse_values], ["indices:0", "values:0"], output.name + ) + else: + dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value") + output = tf.sparse_to_dense(indices, oshape, values, dv) + compare_tf_with_tvm( + [sparse_indices, sparse_values, default_value], + ["indices:0", "values:0", "default_value:0"], + output.name, + ) + + +def test_forward_sparse_to_dense(): + # scalar + _test_sparse_to_dense( + sparse_indices=np.int32(1), + sparse_values=np.int32(3), + default_value=np.int32(0), + output_shape=np.array([5]).astype("int32"), + ) + + # vector + _test_sparse_to_dense( + sparse_indices=np.array([0, 1, 4]).astype("int32"), + sparse_values=np.array([3, 3, 3]).astype("int32"), + default_value=np.int32(0), + output_shape=np.array([5]).astype("int32"), + ) + + # vector nXd + _test_sparse_to_dense( + sparse_indices=np.array([[0, 0], [1, 2]]).astype("int32"), + sparse_values=np.array([1, 2]).astype("int32"), + default_value=np.int32(0), + output_shape=np.array([3, 4]).astype("int32"), + ) + + _test_sparse_to_dense( + sparse_indices=np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"), + sparse_values=np.array([1, 2]).astype("int32"), + default_value=np.int32(4), + output_shape=np.array([2, 3, 4]).astype("int32"), + ) + + # floats + _test_sparse_to_dense( + sparse_indices=np.array([0, 1, 4]).astype("int32"), + sparse_values=np.array([3.1, 3.1, 3.1]).astype("float32"), + default_value=np.float32(3.5), + output_shape=np.array([5]).astype("int32"), + ) + + # default value not specified + _test_sparse_to_dense( + sparse_indices=np.array([0, 1, 4]).astype("int32"), + sparse_values=np.array([3.1, 3.1, 3.1]).astype("float32"), + default_value=None, + output_shape=np.array([5]).astype("int32"), + ) + + ####################################################################### # infinity ops # ------------