diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1e782a568fe9..12db859d1ae1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1146,6 +1146,9 @@ Expr MakeScatterND(Expr data, Expr indices, const Array out_shape) { TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND); +// scatter_nd operator has extern schedules for CPU and GPU devices. +// Fusing extern schedules with Injective schedules leads to errors. +// So, converting the scatter_nd to Opaque to prevent compilation failures RELAY_REGISTER_OP("scatter_nd") .describe(R"code(Scatter elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1158,7 +1161,7 @@ Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with sh .add_argument("indices", "Tensor", "The indices tensor.") .set_support_level(3) .add_type_rel("ScatterND", ScatterNDRel) - .set_attr("TOpPattern", kInjective); + .set_attr("TOpPattern", kOpaque); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 625c47240326..94fac3ba1264 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -1391,17 +1391,46 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5) op_res = intrp.evaluate(func)(data_np, indices_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) + def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5): + data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) + indices_vars = [ + relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np) + ] + + # test if scatter_nd works in case indices are prepared by another Relay operator + indices = relay.op.stack(indices_vars, axis=0) + out = relay.op.scatter_nd(data, indices, shape) + func = relay.Function( + [ + data, + ] + + indices_vars, + out, + ) + + fargs = [ + data_np, + ] + for a in indices_np: + fargs.append(a) + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(*fargs) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol) + data = np.array([2, 3, 0]) indices = np.array([[1, 1, 0], [0, 1, 0]]) shape = (2, 2) out = np.array([[0, 0], [2, 3]]) verify_scatter_nd(data, indices, shape, out) + verify_scatter_nd_with_stack(data, indices, shape, out) data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) indices = np.array([[0, 1], [1, 1]]) shape = (2, 2, 2, 2) out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]]) verify_scatter_nd(data, indices, shape, out) + verify_scatter_nd_with_stack(data, indices, shape, out) data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32") indices = np.array([[1, 0, 0]]) @@ -1411,6 +1440,7 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5) out[0, :] += data[1, :] out[0, :] += data[2, :] verify_scatter_nd(data, indices, shape, out) + verify_scatter_nd_with_stack(data, indices, shape, out) data = np.ones((5, 3)).astype("float64") indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64") @@ -1420,6 +1450,7 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5) for j in range(data.shape[1]): out[indices[0, i], indices[1, i], j] += data[i, j] verify_scatter_nd(data, indices, shape, out) + verify_scatter_nd_with_stack(data, indices, shape, out) if __name__ == "__main__":