diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 2a421182d5f03..c4cb10aed3a47 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -146,11 +146,11 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   Integer batch_dims;
-  Integer gather_dim;
+  Integer num_indices_per_tuple;
 
   TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
     TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
-    TVM_ATTR_FIELD(gather_dim)
+    TVM_ATTR_FIELD(num_indices_per_tuple)
         .set_default(Integer(-1))
         .describe(
             "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
             "indexing tuples is dynamic.");
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index e67ce52898dba..6feed09269d5b 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1418,8 +1418,8 @@ def _impl_common(cls, data, indices, batch_dims=0):
         indices_dims = len(infer_shape(indices))
         indices_shape = infer_shape(indices)
         indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
-        gather_dim = indices_shape[-1]
-        return _op.gather_nd(data, indices, batch_dims, gather_dim)
+        num_indices_per_tuple = indices_shape[-1]
+        return _op.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index a3caaf6344990..60d642d829254 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1077,14 +1077,14 @@ def unique_shape_func(attrs, inputs, _):
 
 
 @script
-def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
+def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tuple):
     ndim = data_shape.shape[0]
     # using mdim = indices_shape[0] wouldn't work because a rank cannot
    # depend on a runtime shape dimension of indices tensor, even if the
    # dimension is always a known, fixed value. As a workaround, we assume that
    # the fixed gather dimension (the size of an indexing tuple) is recorded
    # in `gather_nd` op attribute.
-    mdim = gather_dim
+    mdim = num_indices_per_tuple
     kdim = indices_shape.shape[0] - 1
     out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
     for i in range(1, kdim + 1):
@@ -1100,6 +1100,6 @@ def gather_nd_shape_func(attrs, inputs, _):
     Shape func for gather_nd operator.
     """
     batch_dims = get_const_int(attrs.batch_dims)
-    gather_dim = get_const_int(attrs.gather_dim)
-    assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
-    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
+    num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple)
+    assert num_indices_per_tuple > 0, "num_indices_per_tuple needs to be specified for dynamic gather_nd"
+    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 8c2f0e9bb3305..fdd86b316353a 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
     return _make.gather(data, axis, indices)
 
 
-def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
+def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
     """Gather elements or slices from data and store to a tensor whose
     shape is defined by indices.
 
@@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
     batch_dims : int
         The number of batch dimensions.
 
-    gather_dim : int
+    num_indices_per_tuple : int
         The size of an indexing tuple, which is a fixed value and the same as indices.shape[0].
         Only needed when other dimensions of indices are dynamic.
 
@@ -1112,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
         indices = [[1, 0]]
         relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
     """
-    return _make.gather_nd(data, indices, batch_dims, gather_dim)
+    return _make.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
 
 
 def sequence_mask(data, valid_length, mask_value=0, axis=0):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index c6efb9ee64bb2..2128685d18f35 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3373,11 +3373,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
   return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }
 
-Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int num_indices_per_tuple = -1) {
   static const Op& op = Op::Get("gather_nd");
   auto attrs = make_object<GatherNDAttrs>();
   attrs->batch_dims = batch_dims;
-  attrs->gather_dim = gather_dim;
+  attrs->num_indices_per_tuple = num_indices_per_tuple;
   return Call(op, {data, indices}, Attrs(attrs));
 }