Skip to content

Commit

Permalink
[Relay][Op][Bug] Fix missing return in scatter_nd cuda strategy (apac…
Browse files Browse the repository at this point in the history
…he#7447)

* fix missing return in scatter_nd cuda strategy

* add Relay test for scatter_nd, fix documentation
  • Loading branch information
altanh authored and trevor-m committed Mar 2, 2021
1 parent aee3fd3 commit d89af0a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 39 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ def scatter_nd_cuda(attrs, inputs, out_type, target):
name="scatter_nd.cuda",
plevel=10,
)
return strategy


@sort_strategy.register(["cuda", "gpu"])
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def scatter_nd(data, indices, out_shape):
indices : relay.Expr
The index locations to update.
out_shape : relay.Expr
out_shape : Union[Tuple[int], List[int]]
Output shape of the scatter.
Returns
Expand Down
84 changes: 46 additions & 38 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,42 +1409,50 @@ def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e
verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64")


@tvm.testing.parametrize_targets
def test_scatter_nd(target, ctx):
def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))

out = relay.op.scatter_nd(data, indices, shape)
func = relay.Function([data, indices], out)

for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)

data = np.array([2, 3, 0])
indices = np.array([[1, 1, 0], [0, 1, 0]])
shape = (2, 2)
out = np.array([[0, 0], [2, 3]])
verify_scatter_nd(data, indices, shape, out)

data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
indices = np.array([[0, 1], [1, 1]])
shape = (2, 2, 2, 2)
out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
verify_scatter_nd(data, indices, shape, out)

data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
indices = np.array([[1, 0, 0]])
shape = (2, 1560)
out = np.zeros(shape).astype("float32")
out[1, :] += data[0, :]
out[0, :] += data[1, :]
out[0, :] += data[2, :]
verify_scatter_nd(data, indices, shape, out)

data = np.ones((5, 3)).astype("float64")
indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
shape = (2, 7, 3)
out = np.zeros(shape).astype("float64")
for i in range(indices.shape[1]):
for j in range(data.shape[1]):
out[indices[0, i], indices[1, i], j] += data[i, j]
verify_scatter_nd(data, indices, shape, out)


if __name__ == "__main__":
test_cast()
test_zeros_ones()
test_unary_identity()
test_clip()
test_transpose_infer_type()
test_transpose()
test_reshape_infer_type()
test_reshape()
test_reshape_fail()
test_reshape_like_infer_type()
test_reshape_like()
test_take_infer_type()
test_take()
test_full_infer_type()
test_full()
test_full_like_infer_type()
test_full_like()
test_infer_type_leaky_relu()
test_infer_type_prelu()
test_squeeze()
test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type()
test_split_infer_type()
test_arange()
test_meshgrid()
test_reverse()
test_stack()
test_tile()
test_repeat()
test_gather_nd()
test_isfinite()
test_isinf()
test_unravel_index()
test_sparse_to_dense()
test_fixed_point_multiply()
test_adv_index()
test_interpolate()
test_cumsum()
pytest.main([__file__])

0 comments on commit d89af0a

Please sign in to comment.