diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index b5c83b72ab8d4..94c413b6df6cb 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1102,10 +1102,6 @@ def gather_nd_shape_func(attrs, inputs, _): batch_dims = get_const_int(attrs.batch_dims) index_rank = get_const_int(attrs.index_rank) - assert ( - index_rank > 0 - ), "index_rank needs to be specified for dynamic gather_nd" + assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd" - return [ - _gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank)) - ] + return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 55ea86b47a7f9..74fb44fc2232a 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1072,7 +1072,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dims=0, index_rank=-1): +def gather_nd(data, indices, batch_dims=0, index_rank=None): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, index_rank=-1): batch_dims : int The number of batch dimensions. - index_rank : int + index_rank : int, optional The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] Only needed when other dimensions of indices are dynamic. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index e534fd6c476c5..10fe5e543dfc7 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3373,11 +3373,12 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; } -Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int index_rank = -1) { +Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, + Optional index_rank = NullValue()) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); attrs->batch_dims = batch_dims; - attrs->index_rank = Integer(index_rank); + attrs->index_rank = index_rank; return Call(op, {data, indices}, Attrs(attrs)); }