[Relay, ONNX] Support gather_nd batch_dims attribute for TF/ONNX (apache#8084)

* Add GatherND batch_dim support

* adding tests

* test working

* improved reference code

* refactor ref func

* batch dim 2 tests from tf all passed

* batch_dim -> batch_dims

* add example

* minor change

* add onnx test

* fix onnx version

* fix lint

* remove move on batch_dims

* fix pylint

* fix compiler warning

* add shape constraint for batch_dim and update doc

* make the output shape doc clearer
masahi authored and Trevor Morris committed Jun 17, 2021
1 parent 1e52b37 commit 4f2b9a8
Showing 7 changed files with 140 additions and 21 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -144,6 +144,13 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
}
};

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;

TVM_DECLARE_ATTRS(GatherNDAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
}
};
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer batch_dims;
Integer axis;
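
The new attribute mirrors the batch_dims argument of TensorFlow's tf.gather_nd (and of ONNX GatherND from opset 12). As a quick illustration of the semantics being adopted, here is the example from the TensorFlow documentation that the tests below also exercise (assuming TensorFlow is installed):

import tensorflow as tf

# With batch_dims=1, the leading axis of params and indices is a batch axis:
# indices[b] selects within params[b] instead of within params as a whole.
params = tf.constant([[[0, 1], [2, 3]],
                      [[4, 5], [6, 7]]])  # shape (2, 2, 2)
indices = tf.constant([[1], [0]])         # shape (2, 1): one index per batch row
print(tf.gather_nd(params, indices, batch_dims=1))  # [[2 3] [4 5]], shape (2, 2)
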
10 changes: 7 additions & 3 deletions include/tvm/topi/transform.h
@@ -1238,13 +1238,14 @@ inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
*
* \param data The source array.
* \param indices The indices of the values to extract.
* \param batch_dims The number of batch dimensions.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the gather_nd operation
*/
inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dims = 0,
std::string name = "T_gather_nd", std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
ICHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimension";
@@ -1255,7 +1256,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
for (size_t i = 1; i < ndim_i; ++i) {
out_shape.push_back(indices->shape[i]);
}
for (size_t i = indices_dim0 + batch_dims; i < ndim_d; ++i) {
out_shape.push_back(data->shape[i]);
}
return compute(
@@ -1267,6 +1268,9 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
indices_position.push_back(out_index[i]);
}
Array<PrimExpr> real_indices;
for (size_t i = 0; i < static_cast<size_t>(batch_dims); ++i) {
real_indices.push_back(out_index[i]);
}
for (size_t i = 0; i < indices_dim0; ++i) {
indices_position.Set(0, make_const(DataType::Int(32), i));
if (indices->dtype.is_int()) {
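
A hedged NumPy sketch of what the updated kernel computes: the first batch_dims output coordinates are reused as indices into the batch axes of data (the real_indices loop above), while the gathered coordinates are read down axis 0 of indices. The helper name is illustrative, not a TVM API:

import numpy as np

def gather_nd_ref(data, indices, batch_dims=0):
    # Relay/TOPI layout: the index tuple lives on axis 0 of `indices`.
    m = indices.shape[0]
    out_shape = indices.shape[1:] + data.shape[m + batch_dims:]
    out = np.empty(out_shape, data.dtype)
    for out_idx in np.ndindex(*indices.shape[1:]):
        # The first `batch_dims` output coordinates also index the batch
        # axes of `data`; the gathered part is read down axis 0 of `indices`.
        pos = out_idx[:batch_dims] + tuple(int(indices[(i,) + out_idx]) for i in range(m))
        out[out_idx] = data[pos]
    return out

data = np.arange(8).reshape(2, 2, 2)
print(gather_nd_ref(data, np.array([[1, 0]]), batch_dims=1))  # [[2 3] [4 5]]
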
15 changes: 12 additions & 3 deletions python/tvm/relay/frontend/onnx.py
@@ -1413,11 +1413,20 @@ def _impl_v1(cls, inputs, attr, params):
class GatherND(OnnxOpConverter):
"""Operator converter for GatherND."""

@classmethod
def _impl_common(cls, data, indices, batch_dims=0):
indices_dims = len(infer_shape(indices))
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
return _op.gather_nd(data, indices, batch_dims)

@classmethod
def _impl_v1(cls, inputs, attr, params):
return cls._impl_common(inputs[0], inputs[1])

@classmethod
def _impl_v12(cls, inputs, attr, params):
batch_dims = attr.get("batch_dims", 0)
return cls._impl_common(inputs[0], inputs[1], batch_dims)


class Scatter(OnnxOpConverter):
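
The _impl_common helper bridges a layout difference: ONNX GatherND stores each index tuple along the last axis of indices, while Relay's gather_nd expects it along the first axis. The transpose with axes=[-1] + list(range(indices_dims - 1)) moves the tuple axis to the front; a small NumPy sketch of the same shuffle:

import numpy as np

# Three ONNX-style index tuples, each of length 2, stored on the LAST axis.
onnx_indices = np.array([[0, 1], [1, 0], [0, 0]])   # shape (3, 2)
axes = [-1] + list(range(onnx_indices.ndim - 1))    # [-1, 0]
relay_indices = np.transpose(onnx_indices, axes)    # shape (2, 3)
# Relay layout: the j-th index tuple is now the column relay_indices[:, j].
print(relay_indices)
# [[0 1 0]
#  [1 0 0]]
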
11 changes: 9 additions & 2 deletions python/tvm/relay/op/transform.py
@@ -1111,7 +1111,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
@@ -1123,6 +1123,9 @@ def gather_nd(data, indices):
indices : relay.Expr
The indices of the values to extract.
batch_dims : int
The number of batch dimensions.
Returns
-------
ret : relay.Expr
@@ -1139,8 +1142,12 @@ def gather_nd(data, indices):
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
data = [[[0,1],[2,3]],[[4,5],[6,7]]]
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
return _make.gather_nd(data, indices, batch_dims)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
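
A minimal end-to-end sketch of the new Python API, reproducing the batch_dims=1 example from the docstring (assuming a standard TVM build with the graph executor enabled):

import numpy as np
from tvm import relay

x = relay.var("x", relay.TensorType((2, 2, 2), "float32"))
y = relay.var("y", relay.TensorType((1, 2), "int32"))
func = relay.Function([x, y], relay.gather_nd(x, y, batch_dims=1))

x_data = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], dtype="float32")
y_data = np.array([[1, 0]], dtype="int32")  # one length-1 index tuple per batch row

out = relay.create_executor(kind="graph").evaluate(func)(x_data, y_data)
print(out.numpy())  # [[2. 3.] [4. 5.]]  (use .asnumpy() on older TVM versions)
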
38 changes: 30 additions & 8 deletions src/relay/op/tensor/transform.cc
@@ -3402,21 +3402,34 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const size_t kdim = indices->shape.size() - 1;
ICHECK(size_t(mdim->value) <= ndim) << "GatherND: indices first dimension must not exceed data rank.";

const auto param = attrs.as<GatherNDAttrs>();
ICHECK(param != nullptr);

for (int i = 0; i < param->batch_dims->value; ++i) {
ICHECK(reporter->AssertEQ(
data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple
}

Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i)
oshape.push_back(data->shape[i]);
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<GatherNDAttrs>();
ICHECK(param);
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) {
static const Op& op = Op::Get("gather_nd");
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = batch_dims;
return Call(op, {data, indices}, Attrs(attrs));
}

TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND);
@@ -3425,10 +3438,19 @@ RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.
Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
Optionally, batch_dims, the number of batch dimensions, can be given, whose
default value is 0.
Let B denote batch_dims, and data, indices shape be (X_0, X_1, ..., X_{N-1}),
(M, Y_0, ..., Y_{K-1}) respectively.
When B > 0, indexing will start from the B-th axis, and it must be the case that
X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. The output will have a shape
(X_0, ..., X_{B-1}, Y_B, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N.
When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
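
To make the documented shape rule concrete, here is a small sketch that computes the output shape the same way GatherNDRel does; the helper is illustrative only. Since X_0, ..., X_{B-1} == Y_0, ..., Y_{B-1}, returning (Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}) is equivalent to the (X_0, ..., X_{B-1}, Y_B, ...) form in the text above:

def gather_nd_out_shape(x_shape, i_shape, b=0):
    # i_shape = (M, Y_0, ..., Y_{K-1}); x_shape = (X_0, ..., X_{N-1})
    m, ys = i_shape[0], tuple(i_shape[1:])
    assert m + b <= len(x_shape)             # M + B <= N
    assert tuple(x_shape[:b]) == ys[:b]      # batch dims of data and indices agree
    return ys + tuple(x_shape[m + b:])       # (Y_0, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1})

assert gather_nd_out_shape((2, 2, 2), (1, 2), b=1) == (2, 2)
assert gather_nd_out_shape((3, 2, 2, 3, 4), (3, 3, 2), b=2) == (3, 2)  # M + B == N
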
25 changes: 23 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import numpy as np
import pytest
import scipy
@@ -218,6 +219,12 @@ def make_constant_node(name, data_type, dims, vals):
)


def is_version_greater_than(ver):
    # Compare version components numerically; plain string comparison would
    # mis-order versions such as "1.10.0" vs "1.6.0".
    return tuple(int(x) for x in re.findall(r"\d+", onnx.__version__)[:3]) > tuple(
        int(x) for x in re.findall(r"\d+", ver)[:3]
    )


@tvm.testing.uses_gpu
def test_reshape():
in_shape = (4, 3, 3, 4)
@@ -1002,12 +1009,16 @@ def test_isnan():
_test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {})


def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
def verify_gather_nd(in_shape, indices, out_shape, dtype="float32", batch_dims=0, opset=11):
x = np.random.uniform(size=in_shape).astype(dtype)
indices = np.array(indices, dtype="int64")

y = helper.make_node("GatherND", ["in", "indices"], ["out"])

if opset >= 12:
batch_dims_attr = helper.make_attribute("batch_dims", batch_dims)
y.attribute.append(batch_dims_attr)

graph = helper.make_graph(
[y],
"gather_test",
@@ -1024,7 +1035,7 @@ def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
],
)
model = helper.make_model(graph, producer_name="gather_test")
verify_with_ort_with_inputs(model, [x, indices], [out_shape], opset=opset)


@tvm.testing.uses_gpu
@@ -1034,6 +1045,16 @@ def test_gather_nd():
verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2])
verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2])

if is_version_greater_than("1.6.0"):
verify_gather_nd([2, 2, 2], [[1], [0]], [2, 2], batch_dims=1, opset=12)
verify_gather_nd(
(3, 2, 2, 3, 4),
np.random.randint(low=0, high=2, size=(3, 2, 3), dtype="int64"),
(3, 2),
batch_dims=2,
opset=12,
)


@tvm.testing.uses_gpu
def test_onehot():
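
For cross-checking cases like the opset-12 tests above, a hedged NumPy reference for ONNX-layout GatherND with batch_dims (the function name is illustrative): it flattens the batch axes, gathers per batch element, and restores the batch shape:

import numpy as np

def onnx_gather_nd_ref(data, indices, batch_dims=0):
    # ONNX layout: each index tuple lives on the LAST axis of `indices`.
    batch_shape = data.shape[:batch_dims]
    data_flat = data.reshape((-1,) + data.shape[batch_dims:])
    ind_flat = indices.reshape((-1,) + indices.shape[batch_dims:])
    per_batch = [d[tuple(np.moveaxis(i, -1, 0))] for d, i in zip(data_flat, ind_flat)]
    out = np.stack(per_batch, 0)
    return out.reshape(batch_shape + out.shape[1:])

data = np.random.rand(3, 2, 2, 3, 4).astype("float32")
indices = np.random.randint(0, 2, size=(3, 2, 3)).astype("int64")
assert onnx_gather_nd_ref(data, indices, batch_dims=2).shape == (3, 2)
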
55 changes: 52 additions & 3 deletions tests/python/relay/test_op_level3.py
@@ -1284,14 +1284,40 @@ def verify_gather(data, axis, indices, ref_res):

@tvm.testing.uses_gpu
def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
x = relay.var("x", relay.TensorType(xshape, "float32"))
y = relay.var("y", relay.TensorType(yshape, "int32"))
z = relay.gather_nd(x, y, batch_dims)

func = relay.Function([x, y], z)

x_data = np.random.uniform(size=xshape).astype("float32")

if y_data:
y_data = np.array(y_data, dtype="int32")
else:
y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")

def gather_nd_batch_dims_1_ref(data, indices):
res = []
for i, row in enumerate(data):
indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch
res.append(row[indices_tuple])
# stack on the batch dim
return np.stack(res, 0)

if batch_dims > 1:
x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])

ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)

out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
else:
ref_res = x_data[tuple(y_data)]

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
@@ -1304,6 +1330,29 @@ def verify_gather_nd(xshape, yshape, y_data):
verify_gather_nd((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify_gather_nd((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])

# Examples from tensorflow gather_nd doc
# https://www.tensorflow.org/api_docs/python/tf/gather_nd
verify_gather_nd((2, 2, 2), (1, 2), [[1, 0]], 1)
verify_gather_nd((2, 2, 2), (1, 2, 1), [[[1], [0]]], 1)
verify_gather_nd((2, 2, 2), (2, 2, 1), [[[1], [0]], [[0], [1]]], 1)

# Test cases from tensorflow gather_nd tests kernel_tests/array_ops_test.py
verify_gather_nd((2, 2, 2), (1, 2), None, 1)
verify_gather_nd((2, 2, 2), (2, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (3, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (2, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (1, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (3, 2, 1), None, 1)
verify_gather_nd((2, 2, 3, 2), (2, 2, 2), None, 1)
verify_gather_nd((2, 2, 3, 2), (1, 2, 3), None, 1)

verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (3, 3, 2, 1), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (2, 3, 2, 2), None, 2)
verify_gather_nd((3, 2, 2, 3, 4), (1, 3, 2, 3), None, 2)


def _verify_infiniteness_ops(relay_op, ref_op):
for dtype in ["float32", "float16", "float16", "int32", "int16"]:
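
The reference logic in verify_gather_nd relies on the fact that a gather with batch_dims=B equals a gather with batch_dims=1 once the B batch axes are collapsed into one. A self-contained NumPy check of that reduction for the (3, 2, 2, 3, 4) cases above:

import numpy as np

x = np.random.rand(3, 2, 2, 3, 4).astype("float32")
y = np.random.randint(0, 2, size=(3, 3, 2)).astype("int32")  # Relay layout: tuple on axis 0

# Collapse the two batch axes (3, 2) -> (6,) in both data and indices.
x1 = x.reshape((-1,) + x.shape[2:])       # (6, 2, 3, 4)
y1 = y.reshape((y.shape[0], -1))          # (3, 6)

# batch_dims=1 gather on the flattened arrays, then un-flatten the batch.
res = np.stack([row[tuple(y1[:, i])] for i, row in enumerate(x1)], 0)
out = res.reshape((3, 2) + res.shape[1:])
assert out.shape == (3, 2)
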
