Patch summary (free text before the first "diff --git" header):
  * python/tvm/relay/op/strategy/cuda.py: scatter_nd_cuda gains the
    `return strategy` statement after its add_implementation(...) call.
  * python/tvm/relay/op/transform.py: the scatter_nd docstring now documents
    out_shape as Union[Tuple[int], List[int]] instead of relay.Expr.
  * tests/python/relay/test_op_level3.py: adds test_scatter_nd, parametrized
    over targets, and replaces the hand-maintained list of test calls under
    __main__ with pytest.main([__file__]).
NOTE(review): the __main__ hunk assumes `pytest` is already imported in
test_op_level3.py; that import is not visible in this patch -- confirm.
NOTE(review): line breaks and indentation inside the hunks were reconstructed
from the hunk line counts and Python syntax; verify before applying.

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 346e93445f1c..032d2dd2c8f1 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -826,6 +826,7 @@ def scatter_nd_cuda(attrs, inputs, out_type, target):
         name="scatter_nd.cuda",
         plevel=10,
     )
+    return strategy
 
 
 @sort_strategy.register(["cuda", "gpu"])
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 2f7ae8c262ef..0f8c8e4ab7e1 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -321,7 +321,7 @@ def scatter_nd(data, indices, out_shape):
     indices : relay.Expr
         The index locations to update.
 
-    out_shape : relay.Expr
+    out_shape : Union[Tuple[int], List[int]]
         Output shape of the scatter.
 
     Returns
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 358c1e25b093..dc9d39a166f5 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1409,42 +1409,50 @@ def verify_cumsum(data_np, np_out, axis=None, out_dtype=None, rtol=1e-5, atol=1e
     verify_cumsum(data, np.cumsum(data, dtype="int64"), out_dtype="int64")
 
 
+@tvm.testing.parametrize_targets
+def test_scatter_nd(target, ctx):
+    def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
+
+        out = relay.op.scatter_nd(data, indices, shape)
+        func = relay.Function([data, indices], out)
+
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(data_np, indices_np)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)
+
+    data = np.array([2, 3, 0])
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    shape = (2, 2)
+    out = np.array([[0, 0], [2, 3]])
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+    indices = np.array([[0, 1], [1, 1]])
+    shape = (2, 2, 2, 2)
+    out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
+    indices = np.array([[1, 0, 0]])
+    shape = (2, 1560)
+    out = np.zeros(shape).astype("float32")
+    out[1, :] += data[0, :]
+    out[0, :] += data[1, :]
+    out[0, :] += data[2, :]
+    verify_scatter_nd(data, indices, shape, out)
+
+    data = np.ones((5, 3)).astype("float64")
+    indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
+    shape = (2, 7, 3)
+    out = np.zeros(shape).astype("float64")
+    for i in range(indices.shape[1]):
+        for j in range(data.shape[1]):
+            out[indices[0, i], indices[1, i], j] += data[i, j]
+    verify_scatter_nd(data, indices, shape, out)
+
+
 if __name__ == "__main__":
-    test_cast()
-    test_zeros_ones()
-    test_unary_identity()
-    test_clip()
-    test_transpose_infer_type()
-    test_transpose()
-    test_reshape_infer_type()
-    test_reshape()
-    test_reshape_fail()
-    test_reshape_like_infer_type()
-    test_reshape_like()
-    test_take_infer_type()
-    test_take()
-    test_full_infer_type()
-    test_full()
-    test_full_like_infer_type()
-    test_full_like()
-    test_infer_type_leaky_relu()
-    test_infer_type_prelu()
-    test_squeeze()
-    test_squeeze_infer_type()
-    test_squeeze_bad_axes_infer_type()
-    test_split_infer_type()
-    test_arange()
-    test_meshgrid()
-    test_reverse()
-    test_stack()
-    test_tile()
-    test_repeat()
-    test_gather_nd()
-    test_isfinite()
-    test_isinf()
-    test_unravel_index()
-    test_sparse_to_dense()
-    test_fixed_point_multiply()
-    test_adv_index()
-    test_interpolate()
-    test_cumsum()
+    pytest.main([__file__])