Merge branch 'gather_nd_shape_func' into tmp
masahi committed May 29, 2021
2 parents 61e70b8 + 06ac205 commit d7180f2
Showing 9 changed files with 153 additions and 25 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
  Integer batch_dims;
  Optional<Integer> index_rank;

  TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
    TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
    TVM_ATTR_FIELD(index_rank)
        .set_default(NullValue<Integer>())
        .describe(
            "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
            "indexing tuples is dynamic.");
  }
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
  Integer batch_dims;
  Integer axis;
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -1416,8 +1416,10 @@ class GatherND(OnnxOpConverter):
    @classmethod
    def _impl_common(cls, data, indices, batch_dims=0):
        indices_shape = infer_shape(indices)
        indices_dims = len(indices_shape)
        indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
        index_rank = indices_shape[-1]
        return _op.gather_nd(data, indices, batch_dims, index_rank)

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
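The converter's transpose in _impl_common is the subtle step: ONNX GatherND carries each index tuple along the last axis of indices, while Relay's gather_nd expects it along the first axis, and index_rank is the static length of that tuple. A small numpy sketch of the equivalence (made-up values, not part of the commit):

import numpy as np

data = np.arange(24).reshape(2, 3, 4)
onnx_indices = np.array([[0, 1], [1, 2], [0, 2]])  # (3, 2): three 2-element index tuples

# ONNX semantics: each vector along the last axis is one index tuple into data.
onnx_out = np.stack([data[tuple(row)] for row in onnx_indices])

# Relay semantics: the first axis holds the index tuple, hence the transpose
# with axes=[-1] + list(range(indices_dims - 1)); index_rank is the tuple
# length, here indices_shape[-1] == 2 before the transpose.
relay_style_indices = np.transpose(onnx_indices, axes=(1, 0))  # (2, 3)
relay_out = data[tuple(relay_style_indices)]

assert np.array_equal(onnx_out, relay_out)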
31 changes: 31 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -1074,3 +1074,34 @@ def unique_shape_func(attrs, inputs, _):
        return _unique_with_counts_shape(inputs[0])
    else:
        return _unique_shape(inputs[0])


@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank):
    ndim = data_shape.shape[0]
    # Using mdim = indices_shape[0] wouldn't work because an output rank cannot
    # depend on a runtime shape dimension of the indices tensor, even if that
    # dimension is always a known, fixed value. As a workaround, we assume that
    # the fixed gather dimension (the size of an indexing tuple) is recorded
    # in the gather_nd op attributes.
    mdim = index_rank
    kdim = indices_shape.shape[0] - 1
    out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
    for i in range(1, kdim + 1):
        out_shape[i - 1] = indices_shape[i]
    for i in range(mdim + batch_dims, ndim):
        out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
    return out_shape


@_reg.register_shape_func("gather_nd", False)
def gather_nd_shape_func(attrs, inputs, _):
    """
    Shape func for gather_nd operator.
    """
    batch_dims = get_const_int(attrs.batch_dims)
    index_rank = get_const_int(attrs.index_rank)

    assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"

    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
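In plain terms, the output shape is indices_shape[1:] (the kdim trailing dimensions of indices) followed by the data dimensions that are neither gathered nor batch dimensions. A pure-Python mirror of the formula, checked against two of the new test cases added below (illustrative only, not part of the commit):

def gather_nd_out_shape(data_shape, indices_shape, batch_dims, index_rank):
    # Mirrors _gather_nd_shape with ordinary Python ints.
    mdim, kdim = index_rank, len(indices_shape) - 1
    # kdim dims survive from indices; the rest come from data's un-gathered tail.
    return tuple(indices_shape[1:]) + tuple(data_shape[mdim + batch_dims:])

# data (2, 2), indices (2, 3), batch_dims=0, index_rank=2 -> (3,)
assert gather_nd_out_shape((2, 2), (2, 3), 0, 2) == (3,)
# data (3, 2, 2, 3, 4), indices (3, 3, 2), batch_dims=2, index_rank=3 -> (3, 2)
assert gather_nd_out_shape((3, 2, 2, 3, 4), (3, 3, 2), 2, 3) == (3, 2)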
8 changes: 6 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
    return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0, index_rank=None):
    """Gather elements or slices from data and store to a tensor whose shape is
    defined by indices.

@@ -1087,6 +1087,10 @@ def gather_nd(data, indices, batch_dims=0):
    batch_dims : int
        The number of batch dimensions.

    index_rank : int, optional
        The size of an indexing tuple, which is a fixed value and the same as indices.shape[0].
        Only needed when other dimensions of indices are dynamic.

    Returns
    -------
    ret : relay.Expr

@@ -1108,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0):
        data = [[[0,1],[2,3]],[[4,5],[6,7]]]
        indices = [[1, 0]]
        relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
    """
    return _make.gather_nd(data, indices, batch_dims, index_rank)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
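Stepping back from the diff: with a dynamic dimension in indices, type inference cannot read the output rank from the indices type, so the caller supplies index_rank to gather_nd explicitly. A minimal end-to-end sketch, assuming a recent TVM build with the VM executor (it mirrors the new test_any.py case below):

import numpy as np
import tvm
from tvm import relay

# indices has a dynamic second dimension; its first dimension (the size of
# an indexing tuple) is static, and is exactly what index_rank records.
x = relay.var("x", relay.TensorType((2, 2), "float32"))
y = relay.var("y", relay.TensorType((2, relay.Any()), "int32"))
z = relay.gather_nd(x, y, batch_dims=0, index_rank=2)
mod = tvm.IRModule.from_expr(relay.Function([x, y], z))

data_np = np.random.uniform(size=(2, 2)).astype("float32")
indices_np = np.random.randint(0, 2, size=(2, 3)).astype("int32")

# Dynamic shapes need the VM (or debug) executor rather than the graph executor.
result = relay.create_executor("vm", mod=mod).evaluate()(data_np, indices_np)
np.testing.assert_allclose(result.numpy(), data_np[tuple(indices_np)])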
6 changes: 5 additions & 1 deletion python/tvm/topi/scatter.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
from ..te import extern, hybrid


@@ -206,12 +206,16 @@ def _verify_scatter_nd_inputs(data, indices, updates):
        f"the length of the shape of the output ({len(shape)})."
    )
    for i in range(len(indices.shape) - 1):
        if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
            continue
        assert indices.shape[i + 1] == updates.shape[i], (
            f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
            f"updates[{i}] ({updates.shape[i]})."
        )
    for i in range(mdim, len(data.shape)):
        data_ind = i - mdim + len(indices.shape) - 1
        if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var):
            continue
        assert updates.shape[data_ind] == data.shape[i], (
            f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
            f"of out_shape[{i}] ({data.shape[i]})."
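The isinstance(..., expr.Var) guards added above make the static sanity checks permissive under dynamic shapes: when either side of a dimension comparison is a symbolic variable, the equality cannot be decided at compile time, so the assert is skipped rather than spuriously failing. A standalone sketch of that predicate (illustrative, not part of the commit):

from tvm import te
from tvm.tir import expr

n = te.var("n")  # a symbolic dimension; isinstance(n, expr.Var) is True

def dims_provably_mismatch(a, b):
    # Mirrors the new guards: only compare when both dims are static.
    if isinstance(a, expr.Var) or isinstance(b, expr.Var):
        return False
    return a != b

assert not dims_provably_mismatch(n, 3)  # symbolic vs static: undecidable, skip
assert dims_provably_mismatch(2, 3)      # both static: a genuine mismatch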
4 changes: 3 additions & 1 deletion src/relay/op/tensor/transform.cc
@@ -3373,10 +3373,12 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
  return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
                  Optional<Integer> index_rank = NullValue<Integer>()) {
  static const Op& op = Op::Get("gather_nd");
  auto attrs = make_object<GatherNDAttrs>();
  attrs->batch_dims = batch_dims;
  attrs->index_rank = index_rank;
  return Call(op, {data, indices}, Attrs(attrs));
}

48 changes: 48 additions & 0 deletions tests/python/relay/test_any.py
@@ -25,6 +25,7 @@
from tvm.relay.testing import run_infer_type as infer_type

from utils.assert_diagnostic import DiagnosticTesting
from utils import ref_funcs


def int32(val):
@@ -1703,5 +1704,52 @@ def verify_all_class_non_max_suppression(
    )


@tvm.testing.uses_gpu
def test_gather_nd():
    def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
        x = relay.var("x", relay.TensorType(data_shape, "float32"))
        y = relay.var("y", relay.TensorType(indices_shape, "int32"))
        z = relay.gather_nd(x, y, batch_dims, indices_shape[0])

        mod = tvm.IRModule()
        mod["main"] = relay.Function([x, y], z)

        data_np = np.random.uniform(size=data_shape_np).astype("float32")
        indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")

        ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims)
        check_result([data_np, indices_np], mod, [ref_res])

    verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3))
    verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3))
    verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1)
    verify_gather_nd(
        (relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2
    )


@tvm.testing.uses_gpu
def test_scatter_nd():
    def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
        indices_shape = (2, relay.Any())
        updates_shape = (relay.Any(),)
        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
        indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
        updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))

        out = relay.op.scatter_nd(data, indices, updates, "add")

        mod = tvm.IRModule()
        mod["main"] = relay.Function([data, indices, updates], out)

        check_result([data_np, indices_np, updates_np], mod, [ref_res])

    data = np.zeros((2, 2)).astype("int64")
    indices = np.array([[1, 1, 0], [0, 1, 0]])
    updates = np.array([2, 3, 0])
    out = np.array([[0, 0], [2, 3]])
    verify_scatter_nd(data, indices, updates, out)


if __name__ == "__main__":
    pytest.main([__file__])
22 changes: 2 additions & 20 deletions tests/python/relay/test_op_level3.py
@@ -26,6 +26,7 @@
from tvm.error import TVMError
from tvm.relay import create_executor, transform
from tvm.relay.testing import check_grad, run_infer_type
from utils import ref_funcs


def test_zeros_ones():
@@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
        else:
            y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")

        ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)

        for target, dev in tvm.testing.enabled_targets():
            for kind in ["graph", "debug"]:
48 changes: 48 additions & 0 deletions tests/python/relay/utils/ref_funcs.py
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np


def gather_nd(data_np, indices_np, batch_dims=0):
    """gather_nd implemented using numpy"""
    data_shape = data_np.shape
    indices_shape = indices_np.shape

    def gather_nd_batch_dims_1_ref(data, indices):
        res = []
        for i, row in enumerate(data):
            indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
            res.append(row[indices_tuple])
        # stack on the batch dim
        return np.stack(res, 0)

    if batch_dims > 1:
        data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
        indices_np_reshape = np.reshape(
            indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
        )

        ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)

        out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
        ref_res = np.reshape(ref_res, out_shape)
    elif batch_dims == 1:
        ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
    else:
        ref_res = data_np[tuple(indices_np)]

    return ref_res
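For reference, a quick usage check of the helper with hand-computed values (not part of the commit):

import numpy as np
from utils import ref_funcs  # as the tests import it, with tests/python/relay on sys.path

data = np.arange(12).reshape(3, 4)    # rows [0..3], [4..7], [8..11]
indices = np.array([[0, 2], [1, 3]])  # columns are the index tuples (0, 1) and (2, 3)

# With batch_dims=0 this reduces to data[tuple(indices)]:
assert np.array_equal(ref_funcs.gather_nd(data, indices), np.array([1, 11]))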
