Skip to content

Commit

Permalink
Set TOpPattern=kOpaque for scatter_nd (#7464)
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored Feb 18, 2021
1 parent b7e0cfb commit 84c4b15
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,9 @@ Expr MakeScatterND(Expr data, Expr indices, const Array<Integer> out_shape) {

TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);

// scatter_nd operator has extern schedules for CPU and GPU devices.
// Fusing extern schedules with Injective schedules leads to errors.
// So, converting the scatter_nd to Opaque to prevent compilation failures
RELAY_REGISTER_OP("scatter_nd")
.describe(R"code(Scatter elements or slices from data and store to a tensor
whose shape is defined by indices.
Expand All @@ -1158,7 +1161,7 @@ Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with sh
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(3)
.add_type_rel("ScatterND", ScatterNDRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kOpaque);

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,17 +1391,46 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5)
op_res = intrp.evaluate(func)(data_np, indices_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)

def verify_scatter_nd_with_stack(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5):
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices_vars = [
relay.var("ind{i}", shape=v.shape, dtype=str(v.dtype)) for i, v in enumerate(indices_np)
]

# test if scatter_nd works in case indices are prepared by another Relay operator
indices = relay.op.stack(indices_vars, axis=0)
out = relay.op.scatter_nd(data, indices, shape)
func = relay.Function(
[
data,
]
+ indices_vars,
out,
)

fargs = [
data_np,
]
for a in indices_np:
fargs.append(a)
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(*fargs)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol, atol=atol)

data = np.array([2, 3, 0])
indices = np.array([[1, 1, 0], [0, 1, 0]])
shape = (2, 2)
out = np.array([[0, 0], [2, 3]])
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)

data = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
indices = np.array([[0, 1], [1, 1]])
shape = (2, 2, 2, 2)
out = np.array([[[[0, 0], [0, 0]], [[1, 2], [3, 4]]], [[[0, 0], [0, 0]], [[5, 6], [7, 8]]]])
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)

data = np.reshape(np.arange(1560 * 3), (3, 1560)).astype("float32")
indices = np.array([[1, 0, 0]])
Expand All @@ -1411,6 +1440,7 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5)
out[0, :] += data[1, :]
out[0, :] += data[2, :]
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)

data = np.ones((5, 3)).astype("float64")
indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype("int64")
Expand All @@ -1420,6 +1450,7 @@ def verify_scatter_nd(data_np, indices_np, shape, ref_res, rtol=1e-5, atol=1e-5)
for j in range(data.shape[1]):
out[indices[0, i], indices[1, i], j] += data[i, j]
verify_scatter_nd(data, indices, shape, out)
verify_scatter_nd_with_stack(data, indices, shape, out)


if __name__ == "__main__":
Expand Down

0 comments on commit 84c4b15

Please sign in to comment.