Skip to content

Commit

Permalink
rename to index_rank and make it Optional
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 9e06b84 commit 968f3bd
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
Integer num_indices_per_tuple;
Optional<Integer> index_rank;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
TVM_ATTR_FIELD(num_indices_per_tuple)
.set_default(Integer(-1))
TVM_ATTR_FIELD(index_rank)
.set_default(NullValue<Integer>())
.describe(
"The size of an indexing tuple, which is a fixed value. Only needed when the number of "
"indexting tuples is dynamic.");
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,8 +1418,8 @@ def _impl_common(cls, data, indices, batch_dims=0):
indices_dims = len(infer_shape(indices))
indices_shape = infer_shape(indices)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
num_indices_per_tuple = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
index_rank = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, index_rank)

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,14 +1077,14 @@ def unique_shape_func(attrs, inputs, _):


@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tuple):
def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank):
ndim = data_shape.shape[0]
# using mdim = indices_shape[0] wouldn't work because a rank cannot
# depend on a runtime shape dimension of indices tensor, even if the
# dimension is always a known, fixed value. As a workaround, we assume that
# the fixed gather dimension (the size of an indexing tuple) is recorded
# in gather_nd op attributes.
mdim = num_indices_per_tuple
mdim = index_rank
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
Expand All @@ -1100,12 +1100,12 @@ def gather_nd_shape_func(attrs, inputs, _):
Shape func for gather_nd operator.
"""
batch_dims = get_const_int(attrs.batch_dims)
num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple)
index_rank = get_const_int(attrs.index_rank)

assert (
num_indices_per_tuple > 0
), "num_indices_per_tuple needs to be specified for dynamic gather_nd"
index_rank > 0
), "index_rank needs to be specified for dynamic gather_nd"

return [
_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))
_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))
]
6 changes: 3 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
def gather_nd(data, indices, batch_dims=0, index_rank=-1):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Expand All @@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
batch_dims : int
The number of batch dimensions.
num_indices_per_tuple : int
index_rank : int
The size of an indexing tuple, which is a fixed value and the same as indices.shape[0]
Only needed when other dimensions of indices are dynamic.
Expand All @@ -1112,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
return _make.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
return _make.gather_nd(data, indices, batch_dims, index_rank)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3373,11 +3373,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int num_indices_per_tuple = -1) {
Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int index_rank = -1) {
static const Op& op = Op::Get("gather_nd");
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = batch_dims;
attrs->num_indices_per_tuple = num_indices_per_tuple;
attrs->index_rank = Integer(index_rank);
return Call(op, {data, indices}, Attrs(attrs));
}

Expand Down

0 comments on commit 968f3bd

Please sign in to comment.