diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index b090e3e40063..bdc46d71a77d 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -97,6 +97,13 @@ TVM_DLL Pass LazyGradientInit(); */ TVM_DLL Pass FoldConstant(); +/*! + * \brief Split function with huge number of arguments to smaller pieces. + * + * \return The pass. + */ +TVM_DLL Pass SplitArgs(int max_function_args); + /*! * \brief Fuse operations into expr into seperate functions. * diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cdfd97c780dd..6294e7acea15 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1228,3 +1228,14 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): if missing_op_mode < 0 or missing_op_mode > 2: raise ValueError("Missing op mode is either 0, 1, or 2") return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode) + + +def SplitArgs(max_function_args): + """Split function with huge number of arguments to smaller pieces. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for constant folding. + """ + return _ffi_api.SplitArgs(max_function_args) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index be39a6f6bd25..439674e0468e 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -139,6 +139,10 @@ def max_num_threads(self): def thread_warp_size(self): return int(self.attrs["thread_warp_size"]) + @property + def max_function_args(self): + return int(self.attrs.get("max_function_args", -1)) + @property def device_name(self): return str(self.attrs.get("device", "")) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 23670109e527..ea53c34c793b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -365,6 +365,12 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); + if (targets.size() == 1) { + const auto& target = (*targets.begin()).second; + pass_seqs.push_back( + transform::SplitArgs(target->GetAttr("max_function_args", -1).value())); + } + // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); if (targets.size() == 1) { diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc new file mode 100644 index 000000000000..70d37d822d71 --- /dev/null +++ b/src/relay/transforms/split_args.cc @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file split_args.cc + */ +#include +#include + +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +class ArgumentSplitter : public ExprRewriter { + public: + explicit ArgumentSplitter(int max_function_args) + : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {} + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + if (max_function_args_ < 0) return post; + if (call->op == concat_op_) { + auto op = call->args[0].as(); + const auto param = call->attrs.as(); + int outputsNum = 1; + if (const auto* tuple_type = call->checked_type().as()) { + outputsNum = tuple_type->fields.size(); + } + const int limit = max_function_args_ - outputsNum; + int argsNum = op->fields.size(); + if (argsNum < limit) return post; + int splitNum = argsNum / limit; + splitNum = (argsNum % limit) ? splitNum + 1 : splitNum; + + std::vector splitted(splitNum); + for (int i = 0; i < splitNum; ++i) { + int startIdx = i * limit; + int argsCount = std::min(limit, argsNum - startIdx); + tvm::Array args; + for (int j = 0; j < argsCount; ++j) { + args.push_back(op->fields[j + startIdx]); + } + Tuple tuple(args); + Expr body = MakeConcatenate(tuple, param->axis); + splitted[i] = StopFusion(body); + } + tvm::Array tupleArgs(splitted); + Tuple tuple(tupleArgs); + return MakeConcatenate(tuple, param->axis); + } + return post; + } + + private: + const int max_function_args_; + const Op& concat_op_; +}; + +Expr SplitArgs(const Expr& expr, int max_function_args) { + auto rewriter = ArgumentSplitter(max_function_args); + return PostOrderRewrite(expr, &rewriter); +} + +namespace transform { + +Pass SplitArgs(int max_function_args) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SplitArgs(f, max_function_args)); + }; + return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 71e3529e0d80..b44afec57d5d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -43,7 +43,7 @@ void CodeGenMetal::InitFuncState(const PrimFunc& f) { } } -CodeGenMetal::CodeGenMetal() { +CodeGenMetal::CodeGenMetal(Target target) : target_(target) { decl_stream << "#include \n"; decl_stream << "using namespace metal;\n\n"; decl_stream << "union __TVMArgUnion {\n" @@ -67,6 +67,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // Buffer arguments size_t num_buffer = 0; + int limit = target_->GetAttr("max_function_args").value(); + if (static_cast(f->params.size()) > limit) { + LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " + "buffers in the kernel"; + } for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { Var v = f->params[i]; if (!v.dtype().is_handle()) break; @@ -332,7 +337,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; code << "// Function: " << kv.first->name_hint << std::endl; - CodeGenMetal cg; + CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 614a191907af..9fb8f80303f9 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -35,7 +35,7 @@ namespace codegen { class CodeGenMetal final : public CodeGenC { public: - CodeGenMetal(); + explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); void AddFunction(const PrimFunc& f); // NOLINT(*) @@ -58,6 +58,7 @@ class CodeGenMetal final : public CodeGenC { private: int thread_index_bits_{32}; + Target target_; }; } // namespace codegen } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index b9d9706773f7..d037b9dfdbdb 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -347,10 +347,15 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) .add_attr_option("thread_warp_size", Integer(1)) .set_default_keys({"opencl", "gpu"}); +// The metal has some limitations on the number of input parameters. This is why attribute +// `max_function_args` was introduced. It specifies the maximum number of kernel argumetns. More +// information about this limitation can be found here: +// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .add_attr_option("system-lib") .add_attr_option("max_num_threads", Integer(256)) .add_attr_option("thread_warp_size", Integer(16)) + .add_attr_option("max_function_args", Integer(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) diff --git a/tests/python/relay/test_pass_split_args.py b/tests/python/relay/test_pass_split_args.py new file mode 100644 index 000000000000..2039f464751f --- /dev/null +++ b/tests/python/relay/test_pass_split_args.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.testing import run_infer_type, create_workload + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, tvm.transform.Pass) + + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + mod = opt_pass(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def test_split_concat_metal(): + shape = (1, 1, 1, 3) + dtype = "float32" + axis = 1 + inputs = [] + for i in range(100): + inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype)) + + def before(): + inp = relay.Tuple(inputs) + return relay.op.concatenate(inp, axis) + + def expected(): + limit = tvm.target.Target("metal").max_function_args - 1 # one buffer with output + splitNum = int(len(inputs) / limit) + if len(inputs) % limit > 0: + splitNum += 1 + + splitted = [] + for i in range(splitNum): + startIdx = i * limit + argsCount = min(limit, len(inputs) - startIdx) + args = [] + for j in range(argsCount): + args.append(inputs[j + startIdx]) + t = relay.Tuple(args) + concat = relay.op.concatenate(t, axis) + splitted.append(relay.annotation.stop_fusion(concat)) + inp = relay.Tuple(splitted) + return relay.op.concatenate(inp, axis) + + # the fold constant should work on any context. + res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("metal").max_function_args)) + exp = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(res, exp) + + +def test_split_concat_cuda(): + shape = (1, 1, 1, 3) + dtype = "float32" + axis = 1 + inputs = [] + for i in range(100): + inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype)) + + def before(): + inp = relay.Tuple(inputs) + return relay.op.concatenate(inp, axis) + + def expected(): + inp = relay.Tuple(inputs) + return relay.op.concatenate(inp, axis) + + # the fold constant should work on any context. + res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("cuda").max_function_args)) + exp = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(res, exp) + + +if __name__ == "__main__": + test_split_concat_metal() + test_split_concat_cuda()