From c9028c74d6db30795152539412b285e147813f17 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 24 Jun 2021 08:15:18 +0300 Subject: [PATCH] Add getting number of output parameters --- src/relay/transforms/split_args.cc | 8 ++++++-- src/target/source/codegen_metal.cc | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index cdd596be37a52..dadfb51a3ea68 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -38,7 +38,11 @@ class ArgumentSplitter : public ExprRewriter { if (call->op == concat_op_) { auto op = call->args[0].as(); const auto param = call->attrs.as(); - const int limit = max_function_args_ - 1; // one buffer with output + int outputsNum = 1; + if (const auto* tuple_type = call->checked_type().as()) { + outputsNum = tuple_type->fields.size(); + } + const int limit = max_function_args_ - outputsNum; int argsNum = op->fields.size(); if (argsNum < limit) return post; int splitNum = argsNum / limit; @@ -80,7 +84,7 @@ Pass SplitArgs(int max_function_args) { [=](Function f, IRModule m, PassContext pc) { return Downcast(SplitArgs(f, max_function_args)); }; - return CreateFunctionPass(pass_func, 1, "SplitArgs", {}); + return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 3c5ff89c6f1d5..b44afec57d5d5 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -68,7 +68,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // Buffer arguments size_t num_buffer = 0; int limit = target_->GetAttr("max_function_args").value(); - if (f->params.size() > limit) { + if (static_cast(f->params.size()) > limit) { LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " "buffers in the kernel"; }