Skip to content

Commit

Permalink
add check for num of arguments for topi.take
Browse files Browse the repository at this point in the history
  • Loading branch information
zxy844288792 committed May 3, 2021
1 parent 1423c67 commit bdad09e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int a
}

ICHECK_LT(batch_dims_, a->shape.size()) << "batch_dims out of bounds";
ICHECK_GE(axis, batch_dims_) << "batch_dims must be less than or equal to axis";
ICHECK_LE(batch_dims_, axis) << "batch_dims must be less than or equal to axis";
for (int i = 0; i < batch_dims_; ++i) {
auto addr1 = a->shape[i];
auto addr2 = indices->shape[i];
Expand Down
1 change: 1 addition & 0 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) {
int batch_dims = args[2];
*rv = take(args[0], args[1], batch_dims, mode);
} else {
ICHECK_EQ(args.size(), 5) << "topi.take expects 4 or 5 arguments";
int batch_dims = args[2];
int axis = args[3];
std::string mode = args[4];
Expand Down

0 comments on commit bdad09e

Please sign in to comment.