diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 5f591f1d89ad4..a79bd033c964d 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -120,9 +120,12 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
 /*!
  * \brief Split function with huge number of arguments to smaller pieces.
  *
+ * \param max_function_args Maximum number of function arguments. If it is 0 then SplitArgs won't
+ * split the function.
+ *
  * \return The pass.
  */
-TVM_DLL Pass SplitArgs(int max_function_args);
+TVM_DLL Pass SplitArgs(uint64_t max_function_args);
 
 /*!
  * \brief Fuse operations into expr into separate functions.
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index d881c4f423338..a40c7d63031ea 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -723,7 +723,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
 }
 
 /*!
- * \brief Calcluate the output shape of strided_slice, the entry point for Relay type relation
+ * \brief Calculate the output shape of strided_slice, the entry point for Relay type relation
  *
  * \param ishape The input tensor shape
 * \param begin The indices to begin with in the slicing
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index b8af0518b29c5..c162164daace2 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1376,10 +1376,16 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
 def SplitArgs(max_function_args):
     """Split function with huge number of arguments to smaller pieces.
 
+    Parameters
+    ----------
+    max_function_args: int
+        Maximum number of function arguments. If it is 0 then SplitArgs won't split the function.
+
+    Returns
     -------
     ret : tvm.transform.Pass
-        The registered pass for constant folding.
+        The registered pass.
     """
     return _ffi_api.SplitArgs(max_function_args)
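
For orientation, not part of the patch: a minimal sketch of how the revised `SplitArgs` entry point is driven from Python. The 300-input concatenate is an arbitrary illustration, and the OpenCL limit of 128 assumes the target_kind.cc default added later in this patch:

```python
import tvm
from tvm import relay
from tvm.relay import transform

# A concatenate with far more inputs than a GPU driver may accept as kernel args.
inputs = [relay.var(f"p{i}", shape=(1, 4), dtype="float32") for i in range(300)]
func = relay.Function(inputs, relay.op.concatenate(relay.Tuple(inputs), axis=1))
mod = transform.InferType()(tvm.IRModule.from_expr(func))

# 0 disables splitting; a positive limit (e.g. the target attribute) enables it.
limit = tvm.target.Target("opencl").max_function_args  # 128 with this patch
mod = transform.SplitArgs(limit)(mod)
```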
""" return _ffi_api.SplitArgs(max_function_args) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 0c834c5f026ef..09ee0ac20dfe1 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -194,7 +194,7 @@ def max_shared_memory_per_block(self): @property def max_function_args(self): - return int(self.attrs.get("max_function_args", -1)) + return int(self.attrs.get("max_function_args", 0)) @property def vtcm_capacity(self): diff --git a/src/relay/analysis/graph_partitioner.cc b/src/relay/analysis/graph_partitioner.cc index 861fd58d9e5c8..041e0ba6be050 100644 --- a/src/relay/analysis/graph_partitioner.cc +++ b/src/relay/analysis/graph_partitioner.cc @@ -220,6 +220,88 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); } +size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides) { + size_t any_dims = 0; + for (const auto& dim : ttype->shape) { + if (dim.as()) { + any_dims++; + } + } + if (with_strides && any_dims > 0) any_dims += ttype->shape.size(); + return any_dims; +} + +size_t GraphPartitioner::CountArgs_(const tvm::Object* child, const tvm::Object* till_node) { + if (child == till_node) { + // Calculate number of output arguments + if (auto call_node = GetRef(child).as()) { + if (const auto* ttype = call_node->checked_type().as()) { + return CountAdditionalArgs_(ttype) + 1; + } + } + return 1; + } + if (argsMap_.count(child)) { + return argsMap_[child]; + } + size_t args_num = 0; + if (auto call_node = GetRef(child).as()) { + for (auto& it : call_node->args) { + if (it.as() || it.as()) { + args_num += CountArgs_(it.get(), till_node); + } else if (it.as() || it.as()) { + args_num++; + if (const auto* ttype = it->checked_type().as()) { + args_num += CountAdditionalArgs_(ttype); + } + } + } + } else if (GetRef(child).as() || + GetRef(child).as()) { + args_num++; + if (const auto* ttype = + GetRef(child).as()->checked_type().as()) { + args_num += CountAdditionalArgs_(ttype); + } + } else if (auto tuple_node = GetRef(child).as()) { + for (const auto& it : tuple_node->fields) { + args_num++; + args_num += CountArgs_(it.get(), till_node); + } + } + argsMap_[child] = args_num; + return args_num; +} + +size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) { + auto* outputs_list = child->outputs.head; + size_t output_args = 0; + while (outputs_list != nullptr) { + output_args++; + if (auto call_node = GetRef(outputs_list->value.node->ref).as()) { + if (const auto* ttype = call_node->checked_type().as()) { + output_args += CountAdditionalArgs_(ttype, false); + } + } + outputs_list = outputs_list->next; + } + return (max_function_args_ > output_args) ? max_function_args_ - output_args : 0; +} + +size_t GraphPartitioner::CountFusedArgs(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* till_node) { + const tvm::Object* till_node_ref = (till_node != nullptr) ? 
diff --git a/src/relay/analysis/graph_partitioner.h b/src/relay/analysis/graph_partitioner.h
index 9433aafa119d4..0b0bdd40db5b3 100644
--- a/src/relay/analysis/graph_partitioner.h
+++ b/src/relay/analysis/graph_partitioner.h
@@ -78,7 +78,7 @@ class IndexedForwardGraph {
   std::vector<Node*> post_dfs_order;
 
   /*! \brief Dump the graph into string. */
-  void DebugDump() {
+  void DebugDump() const {
     std::ostringstream os;
     for (size_t i = 0; i < post_dfs_order.size(); ++i) {
       Node* node = post_dfs_order[i];
@@ -162,8 +162,12 @@ class DominatorTree {
  */
 class GraphPartitioner {
  public:
-  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
-      : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
+  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth,
+                            size_t max_function_args)
+      : arena_(arena),
+        opt_level_(opt_level),
+        max_fuse_depth_(max_fuse_depth),
+        max_function_args_(max_function_args) {}
   /*!
    * \brief Group as a union find data structure.
   */
@@ -205,10 +209,14 @@ class GraphPartitioner {
   int opt_level_;
   /*! \brief The maximum number of operations in one fused function */
   size_t max_fuse_depth_;
+  /*! \brief The maximum number of arguments in one fused function */
+  size_t max_function_args_;
   /*! \brief The internal groups. */
   std::vector<Group*> groups_;
   /*! \brief internal field used for deduplication */
   std::unordered_set<IndexedForwardGraph::Node*> visited_;
+  /*! \brief internal field used for caching the number of arguments of a node */
+  std::unordered_map<const tvm::Object*, size_t> argsMap_;
   // Internal implementation of CheckPath
   template <typename F>
   bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
@@ -247,6 +255,9 @@ class GraphPartitioner {
   void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
 
   size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
+  size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true);
+  size_t CountArgs_(const tvm::Object* child, const tvm::Object* till_node);
+  size_t CountArgsLimit_(const IndexedForwardGraph::Node* child);
 
   // Count the number of nodes in a fused subgraph if child is additionally fused.
   // dom_parent is already known to be a part of the subgraph.
@@ -256,6 +267,11 @@ class GraphPartitioner {
   // is important for correct calculation.
   size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
                                      IndexedForwardGraph::Node* dom_parent);
+  // Count the number of arguments in a fused subgraph if child is additionally fused.
+  // The calculation goes from the child node up to till_node. If till_node is
+  // nullptr then arguments are counted up to the beginning of the graph.
+  size_t CountFusedArgs(IndexedForwardGraph::Node* child,
+                        IndexedForwardGraph::Node* till_node = nullptr);
 
   // Initialize the groups.
   void InitGroups(const IndexedForwardGraph& graph);
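
With the new constructor parameter, callers of `GraphPartitioner` do not pass the limit explicitly; as the fuse_ops.cc hunk below shows, `FuseOps` reads it from the current target scope. A hedged usage sketch:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 16), dtype="float32")
func = relay.Function([x], relay.exp(x + relay.const(1.0)))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

# Entering the target scope is all that is needed: FuseOps picks up
# max_function_args from Target::Current() (128 for OpenCL with this patch).
with tvm.target.Target("opencl"):
    mod = relay.transform.FuseOps()(mod)
```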
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index f92b4692ed532..ed9cc2d11b155 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -337,7 +337,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (config_->optional_homogeneous_target.defined()) {
       // This pass currently only supports the homogeneous case.
       pass_seqs.push_back(transform::SplitArgs(
-          config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
+          config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
              .value()
              .IntValue()));
     }
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 4ac04761771d1..78b398f140efa 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1081,6 +1081,13 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   // Always plan devices so the remaining passes don't need to distinguish homogeneous vs
   // heterogeneous execution.
   pass_seqs.push_back(transform::PlanDevices(config_));
+  if (config_->optional_homogeneous_target.defined()) {
+    // This pass currently only supports the homogeneous case.
+    pass_seqs.push_back(transform::SplitArgs(
+        config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
+            .value()
+            .IntValue()));
+  }
 
   pass_seqs.push_back(transform::FuseOps());
   pass_seqs.push_back(transform::AnnotateMemoryScope());
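
The VM pipeline above now mirrors the graph-executor pipeline in build_module.cc. A sketch of exercising the new code path (assumes a TVM build with OpenCL codegen enabled; a device is only needed at run time):

```python
import tvm
from tvm import relay

inputs = [relay.var(f"p{i}", shape=(1, 4), dtype="float32") for i in range(300)]
func = relay.Function(inputs, relay.op.concatenate(relay.Tuple(inputs), axis=1))
mod = tvm.IRModule.from_expr(func)

# SplitArgs now runs inside relay.vm.compile for a homogeneous target,
# so the generated OpenCL kernels stay under max_function_args.
vm_exec = relay.vm.compile(mod, target="opencl", target_host="llvm")
```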
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 9c0d38b115877..ee005aa170523 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -319,9 +319,10 @@ class IndexedForwardGraphCreator : private ExprVisitor {
 
 class FuseMutator : private MixedModeMutator {
  public:
-  FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params)
+  FuseMutator(int fuse_opt_level, size_t max_fuse_depth, size_t max_function_args, bool link_params)
       : fuse_opt_level_(fuse_opt_level),
         max_fuse_depth_(max_fuse_depth),
+        max_function_args_(max_function_args),
         link_params_(link_params) {}
 
   // Run the transform
@@ -334,7 +335,8 @@ class FuseMutator : private MixedModeMutator {
   Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
     // setup the group map.
     auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
-    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
+    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
+                      .Partition(graph);
     for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
       ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
       gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
@@ -347,6 +349,7 @@ class FuseMutator : private MixedModeMutator {
  private:
   int fuse_opt_level_;
   size_t max_fuse_depth_;
+  size_t max_function_args_;
   bool link_params_;
 
   using MixedModeMutator::VisitExpr_;
@@ -548,9 +551,10 @@ class FuseMutator : private MixedModeMutator {
   }
 };
 
-Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, bool link_params,
-             const IRModule& module) {
-  return FuseMutator(fuse_opt_level, max_fuse_depth, link_params).Transform(expr);
+Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, size_t max_function_args,
+             bool link_params, const IRModule& module) {
+  return FuseMutator(fuse_opt_level, max_fuse_depth, max_function_args, link_params)
+      .Transform(expr);
 }
 
 namespace transform {
@@ -567,8 +571,13 @@ Pass FuseOps(int fuse_opt_level) {
         link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value();
         int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
         auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
-        return Downcast<Function>(
-            FuseOps(f, opt_level, max_fuse_depth.value().IntValue(), link_params, m));
+        auto target = Target::Current();
+        size_t max_function_args =
+            (target.defined())
+                ? target->GetAttr("max_function_args", Integer(0)).value().IntValue()
+                : 0;
+        return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value().IntValue(),
+                                          max_function_args, link_params, m));
       };
   return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"});
 }
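
A quick way to inspect the limit that `FuseOps` will pick up for each target kind (a sketch; the printed values reflect the defaults set by this patch):

```python
import tvm

for name in ["opencl", "metal", "cuda"]:
    tgt = tvm.target.Target(name)
    # 0 means "no limit known"; FuseOps and SplitArgs then leave functions untouched.
    print(name, tgt.max_function_args)
```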
diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc
index 00b9a3be3b2ee..6efab8adf7109 100644
--- a/src/relay/transforms/split_args.cc
+++ b/src/relay/transforms/split_args.cc
@@ -31,58 +31,101 @@ namespace relay {
 
 class ArgumentSplitter : public ExprRewriter {
  public:
-  explicit ArgumentSplitter(int max_function_args)
+  explicit ArgumentSplitter(size_t max_function_args)
       : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {}
 
+  Expr ConcatSplitter(const TupleNode* tuple_node, const tvm::Array<relay::Expr>& args, int axis,
+                      size_t limit) {
+    relay::Expr lastExpr;
+    tvm::Array<relay::Expr> new_args;
+    size_t added_args = 0;
+    for (const auto& it : args) {
+      size_t curr_args = 1;
+      if (const auto* ttype = it->checked_type().as<TensorTypeNode>()) {
+        ICHECK(additional_args_cache_.count(ttype));
+        curr_args += additional_args_cache_[ttype];
+      }
+      if (added_args + curr_args > limit) {
+        Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), new_args);
+        Expr body = MakeConcatenate(new_tuple, axis);
+        lastExpr = StopFusion(body);
+        new_args.clear();
+        new_args.push_back(lastExpr);
+        added_args = 1;
+      }
+      added_args += curr_args;
+      new_args.push_back(it);
+    }
+    Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), new_args);
+    Expr body = MakeConcatenate(new_tuple, axis);
+    lastExpr = StopFusion(body);
+    return lastExpr;
+  }
+
+  // In case of a dynamic shape in a tensor, the sizes of dynamic dimensions and the strides
+  // are passed as function arguments
+  size_t CalculateNumberOfAdditionalArgs_(const TensorTypeNode* arg, bool isOutput = false) {
+    size_t num = 0;
+    for (const auto& dim : arg->shape) {
+      if (dim.as<AnyNode>()) {
+        num++;
+      }
+    }
+    // In case of dynamic shape, strides will also be passed to the function
+    // as arguments. The number of strides equals the rank of the tensor.
+    if (num > 0 && isOutput)
+      return arg->shape.size();
+    else if (num > 0)
+      num += arg->shape.size();
+    return num;
+  }
+
   Expr Rewrite_(const CallNode* call, const Expr& post) final {
-    if (max_function_args_ < 0) return post;
+    if (max_function_args_ == 0) return post;
     if (call->op == concat_op_) {
       auto tuple_node = call->args[0].as<TupleNode>();
       const auto param = call->attrs.as<ConcatenateAttrs>();
-      int outputsNum = 1;
+      size_t outputsNum = 1;
       if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
         outputsNum = tuple_type->fields.size();
+        for (const auto& it : tuple_type->fields) {
+          if (const auto* ttype = it.as<TensorTypeNode>()) {
+            outputsNum += CalculateNumberOfAdditionalArgs_(ttype, true);
+          }
+        }
+      } else if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
+        outputsNum += CalculateNumberOfAdditionalArgs_(ttype, true);
       }
-      const int limit = max_function_args_ - outputsNum;
-      int argsNum = tuple_node->fields.size();
-      if (argsNum < limit) return post;
-      int splitNum = argsNum / limit;
-      splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
-
-      std::vector<Expr> splitted(splitNum);
-      for (int i = 0; i < splitNum; ++i) {
-        int startIdx = i * limit;
-        int argsCount = std::min(limit, argsNum - startIdx);
-        tvm::Array<Expr> args;
-        args.reserve(argsCount);
+      CHECK_GT(max_function_args_, outputsNum);
+      size_t limit = max_function_args_ - outputsNum;
 
-        for (int j = 0; j < argsCount; ++j) {
-          args.push_back(tuple_node->fields[j + startIdx]);
+      size_t argsNum = tuple_node->fields.size();
+      for (const auto& it : tuple_node->fields) {
+        if (const auto* ttype = it->checked_type().as<TensorTypeNode>()) {
+          size_t any_dims = CalculateNumberOfAdditionalArgs_(ttype);
+          argsNum += any_dims;
+          additional_args_cache_[ttype] = any_dims;
         }
-        Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), args);
-        Expr body = MakeConcatenate(new_tuple, param->axis);
-        splitted[i] = StopFusion(body);
       }
-      tvm::Array<Expr> tuple_args(splitted);
-      Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), tuple_args);
-      return MakeConcatenate(new_tuple, param->axis);
+      if (argsNum < limit) return post;
+      return ConcatSplitter(tuple_node, tuple_node->fields, param->axis, limit);
     }
     return post;
   }
 
  private:
-  const int max_function_args_;
+  const size_t max_function_args_;
   const Op& concat_op_;
+  std::unordered_map<const TensorTypeNode*, size_t> additional_args_cache_;
 };
 
-Expr SplitArgs(const Expr& expr, int max_function_args) {
+Expr SplitArgs(const Expr& expr, size_t max_function_args) {
   auto rewriter = ArgumentSplitter(max_function_args);
   return PostOrderRewrite(expr, &rewriter);
 }
 
 namespace transform {
 
-Pass SplitArgs(int max_function_args) {
+Pass SplitArgs(uint64_t max_function_args) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         auto r = Downcast<Function>(SplitArgs(f, max_function_args));
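
The budget arithmetic in `ConcatSplitter` deserves a worked example. With `max_function_args_` = 128 and one static output, the input budget is 127 slots, and each chained sub-concatenate spends one slot on the previous partial result. A hedged sketch of the same greedy packing (`split_plan` is a hypothetical helper, not patch code):

```python
def split_plan(arg_costs, limit):
    # Start a new concatenate whenever the next argument would overflow `limit`;
    # the chained result of the previous concatenate costs one slot.
    groups, current, used = [], [], 0
    for cost in arg_costs:
        if used + cost > limit:
            groups.append(current)
            current, used = ["prev"], 1
        current.append(cost)
        used += cost
    groups.append(current)
    return groups


# 300 static tensors (cost 1 each) against a 127-slot budget -> 3 kernels.
print(len(split_plan([1] * 300, 127)))
```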
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index b7105e4bcdfc1..b8c30691e21fa 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -68,8 +68,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
 
   // Buffer arguments
   size_t num_buffer = 0;
-  int limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
-  if (static_cast<int>(f->params.size()) > limit) {
+  size_t limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
+  if (f->params.size() > limit) {
     LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
                     "buffers in the kernel";
   }
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 174ae6d43fc81..bb4d9519ba627 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -365,6 +365,11 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
     .add_attr_option<Integer>("max_num_threads", Integer(256))
     .add_attr_option<Integer>("thread_warp_size", Integer(1))
     .add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
+    // The Qualcomm OpenCL runtime has been observed to crash without any error message
+    // when the number of kernel arguments is too large. OpenCL doesn't specify any
+    // limitation on the number of kernel arguments, but a max_function_args of 128
+    // looks like a reasonable limit.
+    .add_attr_option<Integer>("max_function_args", Integer(128))
     .set_default_keys({"opencl", "gpu"});
 
 // The metal has some limitations on the number of input parameters. This is why attribute
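
Since `max_function_args` is an ordinary target attribute, the 128 default registered above can be overridden per target instance when a specific driver tolerates a different count; this sketch uses the existing target-string syntax:

```python
import tvm

tgt = tvm.target.Target("opencl -max_function_args=64")
assert tgt.max_function_args == 64
```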
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 06c93fbc5549e..7041a6c12939a 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -624,6 +624,12 @@ def expected(n, max_fused_ops):
 
     assert tvm.ir.structural_equal(zz, after)
 
+    with tvm.target.Target("opencl"):
+        with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+            cl_zz = run_opt_pass(z, transform.FuseOps())
+
+        assert tvm.ir.structural_equal(cl_zz, after)
+
 
 link_params = tvm.testing.parameter(False, True)
 
@@ -828,5 +834,123 @@ def expected():
         tvm.testing.assert_allclose(result, ref, rtol=1e-4, atol=1e-4)
 
 
+target_name = tvm.testing.parameter("opencl", "metal", "cuda")
+shape_type = tvm.testing.parameter("dynamic", "static")
+
+
+def test_fuse_max_num_args(target_name, shape_type):
+    if shape_type == "dynamic":
+        shape = (tvm.tir.Any(), 20)
+        number_of_any_dims = 1
+    else:
+        shape = (10, 20)
+        number_of_any_dims = 0
+    ndims = len(shape)
+    ops_num = 300
+
+    def _base_func(name):
+        x = relay.var(name, shape=shape)
+        y = relay.add(x, relay.const(1, "float32"))
+        w = relay.exp(y)
+        return x, w
+
+    def before(n):
+        inp = []
+        out = []
+        for i in range(n):
+            x, w = _base_func(f"x{i}")
+            inp.append(x)
+            out.append(w)
+        w = out[0]
+        for i in range(len(out) - 1):
+            w = relay.add(w, out[i + 1])
+        return relay.Function(inp, w)
+
+    def after(n):
+        def create_base_funcs_sum(args_number, limit, prev=None):
+            added_args = 0
+            inputs = []
+            input_vars = []
+            additional_arg = 0 if prev is None else 1
+            res = None
+            for i in range(args_number + additional_arg):
+                inp, out = _base_func(f"p{i}")
+                if i == 1 and prev is not None:
+                    res = relay.add(inp, res)
+                    input_vars.append(relay.var(f"x{i}", shape=shape))
+                    inputs.append(inp)
+                    added_args += 1 + number_of_any_dims
+                    if number_of_any_dims > 0:
+                        added_args += ndims
+                    continue
+
+                curr_args = 1 + number_of_any_dims
+                if number_of_any_dims > 0:
+                    curr_args += ndims
+
+                if added_args + curr_args > limit:
+                    f = relay.Function(inputs, res)
+                    f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+                    return i - additional_arg, input_vars, f
+
+                input_vars.append(relay.var(f"x{i}", shape=shape))
+                inputs.append(inp)
+                if res is None:
+                    res = out
+                else:
+                    res = relay.add(res, out)
+                added_args += curr_args
+            f = relay.Function(inputs, res)
+            f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+            return args_number, input_vars, f
+
+        def create_accum_func(args_limit):
+            out = None
+            inputs = []
+            if args_limit == 0:
+                for i in range(n):
+                    inputs.append(relay.var(f"x{i}", shape=shape))
+                f = before(n)
+                f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+                out = relay.Call(f, inputs)
+                return relay.Function(inputs, out)
+
+            added_args = 0
+            while added_args < n:
+                # When out is not None, one additional argument is used for the
+                # result of the previous fusing
+                args_number = n - added_args
+                a_num, inp, func = create_base_funcs_sum(args_number, args_limit, out)
+                added_args += a_num
+                inputs.append(inp[0])
+                if len(inp) > 1:
+                    if out is not None:
+                        inp[1] = out
+                    else:
+                        inputs.append(inp[1])
+                else:
+                    if out is not None:
+                        inp.append(out)
+                if len(inp) > 2:
+                    inputs.extend(inp[2:])
+                out = relay.Call(func, inp)
+            return relay.Function(inputs, out)
+
+        args_limit = tvm.target.Target.current().max_function_args - (
+            1 + number_of_any_dims
+        )  # one buffer with output
+        args_limit = max(args_limit, 0)
+        return create_accum_func(args_limit)
+
+    max_fused_ops = ops_num * 5
+    with tvm.target.Target(target_name):
+        with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
+            fused = run_opt_pass(before(ops_num), transform.FuseOps())
+
+        expected = run_opt_pass(after(ops_num), transform.InferType())
+
+        assert tvm.ir.structural_equal(fused, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relay/test_pass_split_args.py b/tests/python/relay/test_pass_split_args.py
index 2039f464751f6..83f224d25e671 100644
--- a/tests/python/relay/test_pass_split_args.py
+++ b/tests/python/relay/test_pass_split_args.py
@@ -22,6 +22,10 @@ from tvm.relay.testing import run_infer_type, create_workload
 
 
+target_name = tvm.testing.parameter("opencl", "metal", "cuda")
+shape_type = tvm.testing.parameter("dynamic", "static")
+
+
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
 
@@ -32,65 +36,70 @@ def run_opt_pass(expr, opt_pass):
     return entry if isinstance(expr, relay.Function) else entry.body
 
 
-def test_split_concat_metal():
-    shape = (1, 1, 1, 3)
-    dtype = "float32"
-    axis = 1
-    inputs = []
-    for i in range(100):
-        inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype))
-
-    def before():
-        inp = relay.Tuple(inputs)
-        return relay.op.concatenate(inp, axis)
-
-    def expected():
-        limit = tvm.target.Target("metal").max_function_args - 1  # one buffer with output
-        splitNum = int(len(inputs) / limit)
-        if len(inputs) % limit > 0:
-            splitNum += 1
-
-        splitted = []
-        for i in range(splitNum):
-            startIdx = i * limit
-            argsCount = min(limit, len(inputs) - startIdx)
-            args = []
-            for j in range(argsCount):
-                args.append(inputs[j + startIdx])
-            t = relay.Tuple(args)
-            concat = relay.op.concatenate(t, axis)
-            splitted.append(relay.annotation.stop_fusion(concat))
-        inp = relay.Tuple(splitted)
-        return relay.op.concatenate(inp, axis)
-
-    # the fold constant should work on any context.
-    res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("metal").max_function_args))
-    exp = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(res, exp)
-
-
-def test_split_concat_cuda():
-    shape = (1, 1, 1, 3)
+def test_split_concat(target_name, shape_type):
+    if shape_type == "dynamic":
+        shape = (tvm.tir.Any(), 1, 1, 3)
+        number_of_any_dims = 1
+    else:
+        shape = (1, 1, 1, 3)
+        number_of_any_dims = 0
+    ndims = len(shape)
     dtype = "float32"
     axis = 1
+    tensors_num = 300
     inputs = []
-    for i in range(100):
+    for i in range(tensors_num):
         inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype))
 
     def before():
         inp = relay.Tuple(inputs)
         return relay.op.concatenate(inp, axis)
 
-    def expected():
-        inp = relay.Tuple(inputs)
-        return relay.op.concatenate(inp, axis)
+    def expected(limit):
+        if limit == 0:
+            return before()
+        limit = limit - 1  # one buffer with output
+        if number_of_any_dims > 0:
+            limit -= ndims
+
+        last_op = None
+        new_args = []
+        added_args = 0
+        num_inputs = 0
+        for inp in inputs:
+            curr_args = 1 + number_of_any_dims
+            if number_of_any_dims > 0:
+                curr_args += ndims
+            num_inputs += curr_args
+            if added_args + curr_args > limit:
+                t = relay.Tuple(new_args)
+                concat = relay.op.concatenate(t, axis)
+                last_op = relay.annotation.stop_fusion(concat)
+                new_args = [last_op]
+                added_args = 1
+            added_args += curr_args
+            new_args.append(inp)
+        t = relay.Tuple(new_args)
+        concat = relay.op.concatenate(t, axis)
+        last_op = relay.annotation.stop_fusion(concat)
+
+        if num_inputs < limit:
+            return before()
+
+        return last_op
 
     # the fold constant should work on any context.
-    res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("cuda").max_function_args))
-    exp = run_opt_pass(expected(), transform.InferType())
+    res = run_opt_pass(
+        before(),
+        transform.SplitArgs(tvm.target.Target(target_name).max_function_args),
+    )
+    limit = tvm.target.Target(target_name).max_function_args
+    exp = run_opt_pass(expected(limit), transform.InferType())
     assert tvm.ir.structural_equal(res, exp)
 
 
 if __name__ == "__main__":
-    test_split_concat_metal()
-    test_split_concat_cuda()
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py
index 83612b7f59798..b03999c9fa6a1 100644
--- a/tests/python/unittest/test_target_codegen_opencl.py
+++ b/tests/python/unittest/test_target_codegen_opencl.py
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
+from tvm import te, relay
 import tvm.testing
 import re
+import pytest
+import numpy as np
 
 target = "opencl"
 
@@ -217,5 +219,194 @@ def _check(target, n, dtype):
     _check(target, 32, "float32")
 
 
+shape_type = tvm.testing.parameter("dynamic", "static")
+executor_type = tvm.testing.parameter("ge", "vm")
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_opencl
+def test_opencl_args_split(executor_type, shape_type):
+    def _get_model():
+        if shape_type == "dynamic":
+            shape = (tvm.tir.Any(), 1, 1, 3)
+        else:
+            shape = (1, 1, 1, 3)
+        shape_np = (1, 1, 1, 3)
+        dtype = "float32"
+        axis = 1
+        tensors_num = 300
+        inputs = []
+        inputs_np = {}
+        for i in range(tensors_num):
+            inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype))
+            inputs_np[f"p{i}"] = np.random.uniform(size=shape_np).astype(dtype)
+
+        inp = relay.Tuple(inputs)
+        concat = relay.op.concatenate(inp, axis)
+        return inputs_np, relay.Function(inputs, concat)
+
+    def ref_impl(inputs):
+        axis = 1
+        return np.concatenate(tuple(inputs), axis=axis)
+
+    def get_maximum_kernel_args(source):
+        def get_kernel_args(source):
+            import re
+
+            p = re.compile(r"__kernel void .+\((.*)\)")
+            args = p.findall(source)
+            return args
+
+        args = get_kernel_args(source)
+        max_args = len(args[0].split(","))
+        for arg_line in args:
+            max_args = max(max_args, len(arg_line.split(",")))
+        return max_args
+
+    def validate():
+        from tvm.contrib import graph_executor
+        from tvm.runtime.vm import VirtualMachine
+
+        input_dict, model = _get_model()
+        if executor_type == "ge":
+            with tvm.transform.PassContext(opt_level=3):
+                lib = relay.build(model, target_host="llvm", target=target)
+            ocl_lib = lib.get_lib()
+        else:
+            module = tvm.IRModule({})
+            module["main"] = model
+            with tvm.transform.PassContext(opt_level=1):
+                lib = relay.vm.compile(module, target=target, target_host="llvm")
+            ocl_lib = lib.module.imported_modules[0]
+        opencl_modules = list(
+            filter(lambda mod: mod.type_key == "opencl", ocl_lib.imported_modules)
+        )
+        assembly = opencl_modules[0].get_source()
+        with tvm.target.Target(target):
+            limit = tvm.target.Target.current().max_function_args
+            max_num = get_maximum_kernel_args(assembly)
+            assert max_num <= limit
+
+        dev = tvm.cl()
+        if executor_type == "ge":
+            module = graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(**input_dict)
+            module.run()
+            tvm_out = module.get_output(0)
+        else:
+            vm = VirtualMachine(lib, dev, "naive")
+            data = {}
+            for k, v in input_dict.items():
+                data[k] = tvm.nd.array(v, dev)
+            vm.set_input("main", **data)
+            vm.invoke_stateful("main")
+            tvm_out = vm.get_outputs()[0]
+
+        np_result = ref_impl(list(input_dict.values()))
+        np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-2, atol=1e-2)
+
+    if executor_type == "ge" and shape_type == "dynamic":
+        pytest.skip()
+    validate()
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_opencl
+def test_opencl_fuse_max_args(executor_type, shape_type):
+    if shape_type == "dynamic":
+        shape = (tvm.tir.Any(), 20)
+    else:
+        shape = (1, 20)
+    shape_np = (1, 20)
+    ops_num = 300
+    dtype = "float32"
+
+    def _base_func(name):
+        x = relay.var(name, shape=shape)
+        y = relay.add(x, relay.const(1, "float32"))
+        w = relay.exp(y)
+        return x, w
+
+    def _get_model():
+        inp = []
+        inputs_np = {}
+        out = []
+        for i in range(ops_num):
+            x, w = _base_func(f"x{i}")
+            inputs_np[f"x{i}"] = np.random.uniform(size=shape_np).astype(dtype)
+            inp.append(x)
+            out.append(w)
+        w = out[0]
+        for i in range(len(out) - 1):
+            w = relay.add(w, out[i + 1])
+
+        return inputs_np, relay.Function(inp, w)
+
+    def ref_impl(inputs):
+        w = np.exp(inputs[0] + 1)
+        for i in range(len(inputs) - 1):
+            w = w + np.exp(inputs[i + 1] + 1)
+        return w
+
+    def get_maximum_kernel_args(source):
+        def get_kernel_args(source):
+            import re
+
+            p = re.compile(r"__kernel void .+\((.*)\)")
+            args = p.findall(source)
+            return args
+
+        args = get_kernel_args(source)
+        max_args = len(args[0].split(","))
+        for arg_line in args:
+            max_args = max(max_args, len(arg_line.split(",")))
+        return max_args
+
+    def validate():
+        from tvm.contrib import graph_executor
+        from tvm.runtime.vm import VirtualMachine
+
+        input_dict, model = _get_model()
+        if executor_type == "ge":
+            with tvm.transform.PassContext(opt_level=3):
+                lib = relay.build(model, target_host="llvm", target=target)
+            ocl_lib = lib.get_lib()
+        else:
+            module = tvm.IRModule({})
+            module["main"] = model
+            with tvm.transform.PassContext(opt_level=1):
+                lib = relay.vm.compile(module, target=target, target_host="llvm")
+            ocl_lib = lib.module.imported_modules[0]
+        opencl_modules = list(
+            filter(lambda mod: mod.type_key == "opencl", ocl_lib.imported_modules)
+        )
+        assembly = opencl_modules[0].get_source()
+        with tvm.target.Target(target):
+            limit = tvm.target.Target.current().max_function_args
+            max_num = get_maximum_kernel_args(assembly)
+            assert max_num <= limit
+
+        dev = tvm.cl()
+        if executor_type == "ge":
+            module = graph_executor.GraphModule(lib["default"](dev))
+            module.set_input(**input_dict)
+            module.run()
+            tvm_out = module.get_output(0)
+        else:
+            vm = VirtualMachine(lib, dev, "naive")
+            data = {}
+            for k, v in input_dict.items():
+                data[k] = tvm.nd.array(v, dev)
+            vm.set_input("main", **data)
+            vm.invoke_stateful("main")
+            tvm_out = vm.get_outputs()[0]
+
+        np_result = ref_impl(list(input_dict.values()))
+        np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-2, atol=1e-2)
+
+    if executor_type == "ge" and shape_type == "dynamic":
+        pytest.skip()
+    validate()
+
+
 if __name__ == "__main__":
     tvm.testing.main()