Skip to content

Commit

Permalink
add shape constraint for batch_dim and update doc
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 20, 2021
1 parent 5692540 commit 23ac5cf
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3353,6 +3353,11 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const auto param = attrs.as<GatherNDAttrs>();
ICHECK(param != nullptr);

for (int i = 0; i < param->batch_dims->value; ++i) {
ICHECK(reporter->AssertEQ(
data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple
}

Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i)
Expand Down Expand Up @@ -3381,9 +3386,15 @@ RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.
Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
Optionally, batch_dims, the number of batch dimensions, can be given, whose
default value is 0.
Let B denote batch_dims, and data, indices shape be (X_0, X_1, ..., X_{N-1}),
(M, Y_0, ..., Y_{K-1}) respectively. When B > 0, indexing will start from the B-th axis,
and it must be the case that X_0, ... X_{B-1} == Y_0, ... Y_{B-1}.
The output will have shape
(Y_0, ..., Y_{B-1}, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. If M + B == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
Expand Down

0 comments on commit 23ac5cf

Please sign in to comment.