Skip to content

Commit

Permalink
Add getting number of output parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
echuraev committed Jun 24, 2021
1 parent a5cf1ac commit c45c2b5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class ArgumentSplitter : public ExprRewriter {
if (call->op == concat_op_) {
auto op = call->args[0].as<TupleNode>();
const auto param = call->attrs.as<ConcatenateAttrs>();
const int limit = max_function_args_ - 1; // one buffer with output
int outputsNum = 1;
if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
outputsNum = tuple_type->fields.size();
}
const int limit = max_function_args_ - outputsNum;
int argsNum = op->fields.size();
if (argsNum < limit) return post;
int splitNum = argsNum / limit;
Expand Down Expand Up @@ -80,7 +84,7 @@ Pass SplitArgs(int max_function_args) {
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SplitArgs(f, max_function_args));
};
return CreateFunctionPass(pass_func, 1, "SplitArgs", {});
return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs);
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
// Buffer arguments
size_t num_buffer = 0;
int limit = target_->GetAttr<Integer>("max_function_args").value();
if (f->params.size() > limit) {
if (static_cast<int>(f->params.size()) > limit) {
LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
"buffers in the kernel";
}
Expand Down

0 comments on commit c45c2b5

Please sign in to comment.