From 23ac5cf612ebd5e73ce13986da4db24db363575f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 03:28:04 +0900 Subject: [PATCH] add shape constraint for batch_dim and update doc --- src/relay/op/tensor/transform.cc | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 137bb73ddc34..fb99fcfe5d3b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3353,6 +3353,11 @@ bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const auto param = attrs.as(); ICHECK(param != nullptr); + for (int i = 0; i < param->batch_dims->value; ++i) { + ICHECK(reporter->AssertEQ( + data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple + } + Array oshape; for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i) @@ -3381,9 +3386,15 @@ RELAY_REGISTER_OP("gather_nd") .describe(R"code(Gather elements or slices from data and store to a tensor whose shape is defined by indices. -Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with -shape (M, Y_0, ..., Y_{K-1}), the output will have shape -(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, +Optionally, batch_dims, the number of batch dimensions, can be given, whose +default value is 0. + +Let B denote batch_dims, and data, indices shape be (X_0, X_1, ..., X_{N-1}), +(M, Y_0, ..., Y_{K-1}) respectively. When B > 0, indexing will start from the B-th axis, +and it must be the case that X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. + +The output will have shape +(Y_0, ..., Y_{B-1}, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. If M + B == N, output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) .set_num_inputs(2)