Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set TOpPattern=kOpaque for scatter_nd #7464

Merged
merged 1 commit into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,9 @@ Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {

TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);

// scatter_nd operator has extern schedules for CPU and GPU devices.
// Fusing extern schedules with Injective schedules leads to errors.
// So, converting the scatter_nd to Opaque to prevent compilation failures
RELAY_REGISTER_OP("scatter_nd")
.describe(R"code(Scatter elements or slices from data and store to a tensor
whose shape is defined by indices.
Expand All @@ -1158,7 +1161,7 @@ Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with sh
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(3)
.add_type_rel("ScatterND", ScatterNDRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kOpaque);

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,17 +1391,46 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5)
op_res = intrp.evaluate(func)(data_np, indices_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)

def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices_vars = [
relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
]

# test if scatter_nd works in case indices are prepared by another Relay operator
indices = relay.op.stack(indices_vars, axis=0)
out = relay.op.scatter_nd(data, indices, shape)
func = relay.Function(
[
data,
]
+ indices_vars,
out,
)

fargs = [
data_np,
]
for a in indices_np:
fargs.append(a)
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(*fargs)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)

data = np.array([2, 3, 0])
indices = np.array([[1, 1, 0], [0, 1, 0]])
shape = (2, 2)
out = np.array([[0, 0], [2, 3]])
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)

data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
indices = np.array([[0, 1], [1, 1]])
shape = (2, 2, 2, 2)
out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)

data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
indices = np.array([[1, 0, 0]])
Expand All @@ -1411,6 +1440,7 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5)
out[0, :] += data[1, :]
out[0, :] += data[2, :]
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)

data = np.ones((5, 3)).astype("float64")
indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
Expand All @@ -1420,6 +1450,7 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5)
for j in range(data.shape[1]):
out[indices[0, i], indices[1, i], j] += data[i, j]
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)


if __name__ == "__main__":
Expand Down