Skip to content

Commit

Permalink
[TF] Add DenseBincount support (#12728)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored Sep 9, 2022
1 parent 1c5ffc6 commit cb08a12
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
55 changes: 55 additions & 0 deletions python/tvm/relay/frontend/tensorflow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2868,6 +2868,60 @@ def _impl(inputs, attr, params, mod):
return _impl


def _dense_bincount():
def _impl(inputs, attr, params, mod):
input = inputs[0] # input: int32, int64. 1D or 2D int Tensor
size = inputs[1] # size: non-negative int scalar Tensor
# weights: int32, int64, float32, or float64 Tensor with the same shape as arr
# or a length-0 Tensor, in which case it acts as all weights equal to 1.
weights = inputs[2]
# Returns: Output: 1D Tensor with length equal to size
# or 2D Tensor with [batch_size, size].
# The counts or summed weights for each value in the range [0, size).

input_dtype = _infer_type(input, mod).checked_type.dtype
input_shape = _infer_shape(input, mod)
is_2d_input = len(input_shape) == 2

if input_dtype == "int64":
warnings.warn(
"Casting an int64 input to int32, since we do not have int64 atomic add"
"needed for bincount yet."
)
input = _op.cast(input, "int32")

is_weights_zero_tensor = True
if weights:
weights_shape = _infer_shape(weights, mod)
is_weights_zero_tensor = weights_shape == (0,)

# Output should have the same dtype as weights.
if is_weights_zero_tensor:
# if weights are length-0 Tensor - output dtype is float32
out_dtype = "float32"
updates = _op.cast(_op.ones_like(input), out_dtype)
else:
out_dtype = _infer_type(weights, mod).checked_type.dtype
updates = weights

if is_2d_input:
batch_arr = _op.take(_op.shape_of(input), _expr.const([0]))
size_arr = _op.reshape(size, [1])
counts_shape = _op.concatenate([batch_arr, size_arr], axis=0)
counts = _op.zeros(counts_shape, out_dtype)
out = _op.scatter_add(counts, input, updates, axis=1)
else:
counts_shape = _op.reshape(size, [1])
counts = _op.zeros(counts_shape, out_dtype)
out = _op.scatter_add(counts, input, updates, axis=0)

if attr["binary_output"]:
out = _op.cast(_op.cast(out, "bool"), out_dtype)
return out

return _impl


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
Expand Down Expand Up @@ -2913,6 +2967,7 @@ def _impl(inputs, attr, params, mod):
"Cosh": AttrCvt("cosh"),
"CropAndResize": _crop_and_resize(),
"DecodeJpeg": _decode_image(),
"DenseBincount": _dense_bincount(),
"DepthToSpace": _depth_to_space(),
"DepthwiseConv2dNative": _conv("depthwise"),
"Dilation2D": _dilation2d(),
Expand Down
41 changes: 41 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5758,5 +5758,46 @@ def test_invert_permutation():
compare_tf_with_tvm(x, "Placeholder:0", out_name, no_gpu=False)


#######################################################################
# DenseBincount
# ----


def _test_dense_bincount(in_shape, size, weights, binary_output):
with tf.Graph().as_default():
inputs = []
data = []
inputs.append(tf.placeholder(shape=in_shape, dtype="int32", name="input0"))
data.append(np.random.uniform(0, size, size=in_shape).astype("int32"))
inputs.append(tf.placeholder(shape=(), dtype="int32", name="size"))
data.append(np.array(size, "int32"))
if weights:
inputs.append(tf.placeholder(shape=in_shape, dtype="float32", name="weights"))
data.append(np.reshape(weights, in_shape).astype("float32"))
else:
inputs.append(tf.placeholder(shape=(0,), dtype="float32", name="weights"))
data.append(np.array([], "float32"))
result = tf.raw_ops.DenseBincount(
input=data[0],
size=data[1],
weights=data[2],
binary_output=binary_output,
)
compare_tf_with_tvm(data, [a.name for a in inputs], result.name, mode="vm")


def test_forward_dense_bincount():
"""Test DenseBincount Op"""
for binary_output in [False, True]:
# 2D input
_test_dense_bincount((3, 10), 20, [1.0] * 30, binary_output)
_test_dense_bincount((3, 10), 20, [1.5] * 30, binary_output)
_test_dense_bincount((3, 10), 20, None, binary_output)
# 1D input
_test_dense_bincount((10,), 20, [1.0] * 10, binary_output)
_test_dense_bincount((10,), 20, [1.5] * 10, binary_output)
_test_dense_bincount((10,), 20, None, binary_output)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit cb08a12

Please sign in to comment.