From d4b80365708f0d82a80e9bc563284bdd1f3180a7 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Tue, 18 Jul 2023 10:36:28 +0300 Subject: [PATCH] [Relay] Introduce arguments limit to FuseOps pass In PR #8313 a parameter `max_function_args` was introduced. It leads to limit number of function argument and in case when this value is exceeded then concatenation layer is split to a several concat operations. I faced a problem on Adreno GPU that for kernel with big number of arguments the enqueueNDRange was crashed without any errors. The problem appeared because of the huge number of arguments. But in this case not only concat layer was a root cause of the problem. Also after fusing several operations the final functions had a big number of arguments. As it was discussed in #8313, adding a limitation on the number of function arguments to the FuseOps pass might be a good improvement. In this PR I introduced such mechanism for limitation number of function arguments for FuseOps pass and add an arguments limit to OpenCL devices at 128 parameters. The idea of current approach is calculate the number of arguments for each node in fusing algorithm and in case then the number of function arguments exceeds the limit, specified by `max_function_args`, then the fusing should be stopped. In case when node has several inputs and for some of the inputs the number of arguments wasn't computed, then we postpone fusing for this node and will try fuse this node later when the number of arguments will be computed for all inputs. This approach with postponed fusing helps to avoid additional computations during compilation. Additionally, case of dynamic shapes should be handled. In case of dynamic shape, function arguments also included sizes of dynamic dimension and strides. The number of strides can be computed by calculating number of tensor dimensions (the number of strides equals to the rank of the tensor). The number of additional parameters with sizes of dynamic dimensions can be calculated by computing number of dynamic dimensions. --- include/tvm/relay/transform.h | 2 +- python/tvm/relay/op/tensor.py | 9 +- src/relay/analysis/graph_partitioner.cc | 160 +++++++----- src/relay/analysis/graph_partitioner.h | 34 ++- src/relay/transforms/split_args.cc | 12 +- tests/python/relay/test_pass_fuse_ops.py | 59 ++--- tests/python/relay/test_pass_split_args.py | 23 +- .../unittest/test_target_codegen_opencl.py | 243 ++++++++++-------- 8 files changed, 307 insertions(+), 235 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a79bd033c964d..da4d05f0e63e7 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -121,7 +121,7 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false); * \brief Split function with huge number of arguments to smaller pieces. * * \param max_function_args Maximum number of function arguments. If it is 0 then SplitArgs won't - * split funciton. + * split function. * * \return The pass. */ diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 6b488719eb84c..26caa4584c79b 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -23,7 +23,7 @@ from . import _make from .dyn import _make as _dyn_make -from ..expr import Tuple, Expr, Constant +from ..expr import Tuple, Expr, Constant, Call from . import op as reg @@ -1141,12 +1141,15 @@ def concatenate(data, axis): result: relay.Expr The concatenated tensor. """ - data = list(data) + if not isinstance(data, Call): + data = list(data) if not data: raise ValueError("relay.concatenate requires data to be non-empty.") + if not isinstance(data, Call): + data = Tuple(data) if not isinstance(axis, int): raise ValueError("For now, we only support integer axis") - return _make.concatenate(Tuple(data), axis) + return _make.concatenate(data, axis) def einsum(data, equation): diff --git a/src/relay/analysis/graph_partitioner.cc b/src/relay/analysis/graph_partitioner.cc index 041e0ba6be050..46dd99a25ab0d 100644 --- a/src/relay/analysis/graph_partitioner.cc +++ b/src/relay/analysis/graph_partitioner.cc @@ -169,6 +169,7 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { if (child == parent) return; // update the number of nodes of the parent group parent->num_nodes += child->num_nodes; + parent->args_num += child->args_num; child->parent = parent; // update anchor ref and pattern if (child->anchor_ref != nullptr) { @@ -193,6 +194,10 @@ void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwar } void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + if (postpone_node_ != nullptr) { + postponed_fusing_map_.insert({postpone_node_, src}); + return; + } Group* target = groups_[sink->index]; visited_.clear(); ICHECK(src != sink); @@ -220,6 +225,45 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); } +size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src, + const IndexedForwardGraph& graph, bool update_postpone) { + std::unordered_set visited_groups; + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + auto sum = gnode->args_num; + visited_groups.insert(gnode->FindRoot()); + auto calcArgs = [this, src, &graph, &visited_groups, + update_postpone](const relay::Expr& arg) -> size_t { + if (arg.as()) return 0; + auto* node = graph.node_map.at(arg.get()); + Group* prev_group = groups_[node->index]->FindRoot(); + if (visited_groups.count(prev_group) == 0) { + visited_groups.insert(prev_group); + if (prev_group->args_num > 0) { + // Get number of arguments from group + return prev_group->args_num; + } else if (update_postpone) { + // Update pointer to node which should be postponed for deferred fusing + postpone_node_ = src; + } else { + // Calculate number of arguments for the node which wasn't processed before + return CountArgs_(node, graph, update_postpone); + } + } + return 0; + }; + if (auto call_node = GetRef(src->ref).as()) { + for (auto& it : call_node->args) { + sum += calcArgs(it); + } + } else if (auto tuple_node = GetRef(src->ref).as()) { + for (auto& it : tuple_node->fields) { + sum += calcArgs(it); + } + } + return sum; +} + size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides) { size_t any_dims = 0; for (const auto& dim : ttype->shape) { @@ -231,48 +275,6 @@ size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool return any_dims; } -size_t GraphPartitioner::CountArgs_(const tvm::Object* child, const tvm::Object* till_node) { - if (child == till_node) { - // Calculate number of output arguments - if (auto call_node = GetRef(child).as()) { - if (const auto* ttype = call_node->checked_type().as()) { - return CountAdditionalArgs_(ttype) + 1; - } - } - return 1; - } - if (argsMap_.count(child)) { - return argsMap_[child]; - } - size_t args_num = 0; - if (auto call_node = GetRef(child).as()) { - for (auto& it : call_node->args) { - if (it.as() || it.as()) { - args_num += CountArgs_(it.get(), till_node); - } else if (it.as() || it.as()) { - args_num++; - if (const auto* ttype = it->checked_type().as()) { - args_num += CountAdditionalArgs_(ttype); - } - } - } - } else if (GetRef(child).as() || - GetRef(child).as()) { - args_num++; - if (const auto* ttype = - GetRef(child).as()->checked_type().as()) { - args_num += CountAdditionalArgs_(ttype); - } - } else if (auto tuple_node = GetRef(child).as()) { - for (const auto& it : tuple_node->fields) { - args_num++; - args_num += CountArgs_(it.get(), till_node); - } - } - argsMap_[child] = args_num; - return args_num; -} - size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) { auto* outputs_list = child->outputs.head; size_t output_args = 0; @@ -288,21 +290,47 @@ size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) return (max_function_args_ > output_args) ? max_function_args_ - output_args : 0; } -size_t GraphPartitioner::CountFusedArgs(IndexedForwardGraph::Node* child, - IndexedForwardGraph::Node* till_node) { - const tvm::Object* till_node_ref = (till_node != nullptr) ? till_node->ref : nullptr; +size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph, + IndexedForwardGraph::Node* child) { + size_t args_num = 0; auto* outputs_list = child->outputs.head; - size_t res = 1; while (outputs_list != nullptr) { - size_t output_args = 0; - output_args += CountArgs_(outputs_list->value.node->ref, till_node_ref); - res = std::max(res, output_args); + args_num = std::max(args_num, CountArgs_(outputs_list->value.node, graph)); outputs_list = outputs_list->next; } - return res; + return args_num; } void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { + auto args_counter = [this](const tvm::Object* obj) { + size_t args_num = 0; + if (auto call_node = GetRef(obj).as()) { + for (auto& it : call_node->args) { + if (it.as() || it.as()) { + args_num++; + if (const auto* ttype = it.as()->checked_type().as()) { + args_num += CountAdditionalArgs_(ttype); + } + } + } + } else if (auto tuple_node = GetRef(obj).as()) { + for (auto& it : tuple_node->fields) { + if (it.as() || it.as()) { + args_num++; + if (const auto* ttype = it.as()->checked_type().as()) { + args_num += CountAdditionalArgs_(ttype); + } + } + } + } else if (GetRef(obj).as()) { + args_num++; + if (const auto* ttype = + GetRef(obj).as()->checked_type().as()) { + args_num += CountAdditionalArgs_(ttype); + } + } + return args_num; + }; groups_.resize(graph.post_dfs_order.size()); for (size_t nid = 0; nid < groups_.size(); ++nid) { const auto* graph_node = graph.post_dfs_order[nid]; @@ -313,6 +341,7 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { if (group_node->pattern == relay::kOutEWiseFusable) { group_node->anchor_ref = graph_node->ref; } + group_node->args_num = args_counter(graph_node->ref); groups_[nid] = group_node; } } @@ -320,14 +349,25 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // const DominatorTree& post_dom_tree, // int phase) { - IndexedForwardGraph::Node* prev_node = nullptr; - argsMap_.clear(); for (size_t nid = 0; nid < groups_.size(); ++nid) { // the group of current node has been specified already. auto* graph_node = graph.post_dfs_order[nid]; auto* dom_node = post_dom_tree.nodes[nid]; Group* group_node = groups_[nid]; ICHECK(group_node != nullptr); + postpone_node_ = nullptr; + if (postponed_fusing_map_.count(graph_node)) { + auto range = postponed_fusing_map_.equal_range(graph_node); + for (auto it = range.first; it != range.second; ++it) { + if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) { + auto* src = it->second; + auto* snode = post_dom_tree.nodes[src->index]->parent->gnode; + if (groups_[snode->index]->anchor_ref != nullptr) continue; + CommitFuse(src, snode); + } + } + postponed_fusing_map_.erase(graph_node); + } // no actions for opaque nodes if (group_node->pattern == kOpaque) continue; // no actions needed if the current node have no dominator @@ -339,12 +379,16 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) continue; // refuse the fusion if too many arguments are going to be in fused function - auto limit = CountArgsLimit_(graph_node); - if (limit > 0) { - if (CountFusedArgs(graph_node, prev_node) > limit) { - argsMap_.clear(); - prev_node = graph_node; - continue; + if (max_function_args_ > 0) { + auto limit = CountArgsLimit_(graph_node); + if (limit > 0) { + auto args = CountFusedArgs(graph, graph_node); + //std::cout << "args: " << args << ", limit: " << limit << std::endl; + if (args > limit) { + //std::cout << " >>> args: " << args << ", limit: " << limit << std::endl; + //Dump(graph_node->ref); + continue; + } } } diff --git a/src/relay/analysis/graph_partitioner.h b/src/relay/analysis/graph_partitioner.h index 0b0bdd40db5b3..a391d1464da66 100644 --- a/src/relay/analysis/graph_partitioner.h +++ b/src/relay/analysis/graph_partitioner.h @@ -187,6 +187,10 @@ class GraphPartitioner { * \brief The number of nodes belonging to this group */ uint32_t num_nodes{1}; + /*! + * \brief The number of function arguments belonging to this group + */ + size_t args_num{0}; /*! \brief Optional attributes to annotate the grouped function. */ runtime::Map attrs; @@ -215,8 +219,15 @@ class GraphPartitioner { std::vector groups_; /*! \brief internal field used for deduplication */ std::unordered_set visited_; - /*! \brief internal field used for hashing arguments number for a node */ - std::unordered_map argsMap_; + /*! \brief The map with nodes which were postponed for fusing. */ + std::unordered_multimap + postponed_fusing_map_; + /*! + * \brief Fusing of this node should be postponed till all child nodes will be evaluated. + * It is used to calculate number of arguments which will be passed to this node in + * generated function. + */ + const IndexedForwardGraph::Node* postpone_node_{nullptr}; // Internal implementation of CheckPath template bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); @@ -255,8 +266,22 @@ class GraphPartitioner { void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + // Count the number of additional arguments. In case of dynamic shape, + // generated function takes several additional arguments, such as size of + // dynamic dimension and strides. + // This function calculates number of such additional arguments. size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true); - size_t CountArgs_(const tvm::Object* child, const tvm::Object* till_node); + // Calculate the number of arguments for the node. + size_t CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph, + bool update_postpone = true); + // Count actual limit of arguments for a generated function. + // max_function_args_ specifies the number of maximum function arguments. But + // usually, output tensors also passed to the function as arguments. + // Additionally, in case of dynamic shape, it is necessary to take into + // account the number of parameters which specifies the size of dynamic + // dimension. + // This function computes limit of arguments by the following formula: + // limit = max_function_args_ - output_args_count size_t CountArgsLimit_(const IndexedForwardGraph::Node* child); // Count the number of nodes in a fused subgraph if child is additionally fused. @@ -270,8 +295,7 @@ class GraphPartitioner { // Count the number of arguments in a fused subgraph if child is additionally fused. // Calculation goes from child node till the till_node. If till_node is a // nullptr then calculate arguments till the beginning of the graph. - size_t CountFusedArgs(IndexedForwardGraph::Node* child, - IndexedForwardGraph::Node* till_node = nullptr); + size_t CountFusedArgs(const IndexedForwardGraph& graph, IndexedForwardGraph::Node* child); // Initialize the groups. void InitGroups(const IndexedForwardGraph& graph); diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index 6efab8adf7109..6ef404ee814d9 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -36,7 +36,6 @@ class ArgumentSplitter : public ExprRewriter { Expr ConcatSplitter(const TupleNode* tuple_node, const tvm::Array& args, int axis, size_t limit) { - relay::Expr lastExpr; tvm::Array new_args; size_t added_args = 0; for (const auto& it : args) { @@ -47,18 +46,18 @@ class ArgumentSplitter : public ExprRewriter { } if (added_args + curr_args > limit) { Tuple new_tuple = WithFields(GetRef(tuple_node), new_args); - Expr body = MakeConcatenate(new_tuple, axis); - lastExpr = StopFusion(body); + Expr stop = StopFusion(new_tuple); + Expr lastExpr = MakeConcatenate(stop, axis); new_args.clear(); new_args.push_back(lastExpr); - added_args = 1; + added_args = curr_args; } added_args += curr_args; new_args.push_back(it); } Tuple new_tuple = WithFields(GetRef(tuple_node), new_args); - Expr body = MakeConcatenate(new_tuple, axis); - lastExpr = StopFusion(body); + Expr stop = StopFusion(new_tuple); + Expr lastExpr = MakeConcatenate(stop, axis); return lastExpr; } @@ -83,6 +82,7 @@ class ArgumentSplitter : public ExprRewriter { if (max_function_args_ == 0) return post; if (call->op == concat_op_) { auto tuple_node = call->args[0].as(); + if (tuple_node == nullptr) return post; const auto param = call->attrs.as(); size_t outputsNum = 1; if (const auto* tuple_type = call->checked_type().as()) { diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 7041a6c12939a..4dd05ee782116 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -850,8 +850,7 @@ def test_fuse_max_num_args(target_name, shape_type): def _base_func(name): x = relay.var(name, shape=shape) - y = relay.add(x, relay.const(1, "float32")) - w = relay.exp(y) + w = relay.add(x, relay.const(1, "float32")) return x, w def before(n): @@ -867,22 +866,14 @@ def before(n): return relay.Function(inp, w) def after(n): - def create_base_funcs_sum(args_number, limit, prev=None): + def create_fused_func(limit): added_args = 0 inputs = [] input_vars = [] - additional_arg = 0 if prev is None else 1 res = None - for i in range(args_number + additional_arg): + i = 0 + while added_args < limit: inp, out = _base_func(f"p{i}") - if i == 1 and prev is not None: - res = relay.add(inp, res) - input_vars.append(relay.var(f"x{i}", shape=shape)) - inputs.append(inp) - added_args += 1 + number_of_any_dims - if number_of_any_dims > 0: - added_args += ndims - continue curr_args = 1 + number_of_any_dims if number_of_any_dims > 0: @@ -891,7 +882,7 @@ def create_base_funcs_sum(args_number, limit, prev=None): if added_args + curr_args > limit: f = relay.Function(inputs, res) f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - return i - additional_arg, input_vars, f + return i, input_vars, f input_vars.append(relay.var(f"x{i}", shape=shape)) inputs.append(inp) @@ -900,9 +891,10 @@ def create_base_funcs_sum(args_number, limit, prev=None): else: res = relay.add(res, out) added_args += curr_args + i += 1 f = relay.Function(inputs, res) f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - return args_number, input_vars, f + return i, input_vars, f def create_accum_func(args_limit): out = None @@ -915,25 +907,24 @@ def create_accum_func(args_limit): out = relay.Call(f, inputs) return relay.Function(inputs, out) - added_args = 0 - while added_args < n: - # When out is not none that means that one additional argument - # will be used for the result of previous fusing - args_number = n - added_args - a_num, inp, func = create_base_funcs_sum(args_number, args_limit, out) - added_args += a_num - inputs.append(inp[0]) - if len(inp) > 1: - if out is not None: - inp[1] = out - else: - inputs.append(inp[1]) - else: - if out is not None: - inp.append(out) - if len(inp) > 2: - inputs.extend(inp[2:]) - out = relay.Call(func, inp) + i, inputs, func = create_fused_func(args_limit) + out = relay.Call(func, inputs) + while i < n: + inp, func = _base_func(f"p{i}") + inputs.append(relay.var(f"xa{i}", shape=shape)) + curr_args = 1 + number_of_any_dims + if number_of_any_dims > 0: + curr_args += ndims + f = relay.Function([inp], func) + f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + w = relay.Call(f, [inputs[-1]]) + a = relay.var(f"a", shape=shape) + b = relay.var(f"b", shape=shape) + out_add = relay.add(a, b) + f = relay.Function([a, b], out_add) + f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + out = relay.Call(f, [out, w]) + i += 1 return relay.Function(inputs, out) args_limit = tvm.target.Target.current().max_function_args - ( diff --git a/tests/python/relay/test_pass_split_args.py b/tests/python/relay/test_pass_split_args.py index 83f224d25e671..508f74f11269c 100644 --- a/tests/python/relay/test_pass_split_args.py +++ b/tests/python/relay/test_pass_split_args.py @@ -62,7 +62,6 @@ def expected(limit): if number_of_any_dims > 0: limit -= ndims - last_op = None new_args = [] added_args = 0 num_inputs = 0 @@ -73,30 +72,24 @@ def expected(limit): num_inputs += curr_args if added_args + curr_args > limit: t = relay.Tuple(new_args) - concat = relay.op.concatenate(t, axis) - last_op = relay.annotation.stop_fusion(concat) - new_args = [last_op] - added_args = 1 + stop = relay.annotation.stop_fusion(t) + concat = relay.op.concatenate(stop, axis) + new_args = [concat] + added_args = curr_args added_args += curr_args new_args.append(inp) t = relay.Tuple(new_args) - concat = relay.op.concatenate(t, axis) - last_op = relay.annotation.stop_fusion(concat) + stop = relay.annotation.stop_fusion(t) + concat = relay.op.concatenate(stop, axis) if num_inputs < limit: return before() - return last_op + return concat # the fold constant should work on any context. - res = run_opt_pass( - before(), - transform.SplitArgs(tvm.target.Target(target_name).max_function_args), - ) - exp = run_opt_pass( - expected(tvm.target.Target(target_name).max_function_args), transform.InferType() - ) limit = tvm.target.Target(target_name).max_function_args + res = run_opt_pass(before(), transform.SplitArgs(limit)) exp = run_opt_pass(expected(limit), transform.InferType()) assert tvm.ir.structural_equal(res, exp) diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index b03999c9fa6a1..dcb43f29daf49 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -219,6 +219,62 @@ def _check(target, n, dtype): _check(target, 32, "float32") +def _get_maximum_kernel_args(source): + def get_kernel_args(source): + import re + + p = re.compile(r"__kernel void .+\((.*)\)") + args = p.findall(source) + return args + + args = get_kernel_args(source) + max_args = len(args[0].split(",")) + for arg_line in args: + max_args = max(max_args, len(arg_line.split(","))) + return max_args + + +def _validate_opencl_executors(executor_type, get_model, ref_impl): + from tvm.contrib import graph_executor + from tvm.runtime.vm import VirtualMachine + + input_dict, model = get_model() + if executor_type == "ge": + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(model, target_host="llvm", target=target) + ocl_lib = lib.get_lib() + else: + module = tvm.IRModule({}) + module["main"] = model + with tvm.transform.PassContext(opt_level=3): + lib = relay.vm.compile(module, target=target, target_host="llvm") + ocl_lib = lib.module.imported_modules[0] + opencl_modules = list(filter(lambda mod: mod.type_key == "opencl", ocl_lib.imported_modules)) + assembly = opencl_modules[0].get_source() + with tvm.target.Target(target): + limit = tvm.target.Target.current().max_function_args + max_num = _get_maximum_kernel_args(assembly) + assert max_num <= limit + + dev = tvm.cl() + if executor_type == "ge": + module = graph_executor.GraphModule(lib["default"](dev)) + module.set_input(**input_dict) + module.run() + tvm_out = module.get_output(0) + else: + vm = VirtualMachine(lib, dev, "naive") + data = {} + for k, v in input_dict.items(): + data[k] = tvm.nd.array(v, dev) + vm.set_input("main", **data) + vm.invoke_stateful("main") + tvm_out = vm.get_outputs()[0] + + np_result = ref_impl(list(input_dict.values())) + np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + shape_type = tvm.testing.parameter("dynamic", "static") executor_type = tvm.testing.parameter("ge", "vm") @@ -249,65 +305,9 @@ def ref_impl(inputs): axis = 1 return np.concatenate(tuple(inputs), axis=axis) - def get_maximum_kernel_args(source): - def get_kernel_args(source): - import re - - p = re.compile(r"__kernel void .+\((.*)\)") - args = p.findall(source) - return args - - args = get_kernel_args(source) - max_args = len(args[0].split(",")) - for arg_line in args: - max_args = max(max_args, len(arg_line.split(","))) - return max_args - - def validate(): - from tvm.contrib import graph_executor - from tvm.runtime.vm import VirtualMachine - - input_dict, model = _get_model() - if executor_type == "ge": - with tvm.transform.PassContext(opt_level=3): - lib = relay.build(model, target_host="llvm", target=target) - ocl_lib = lib.get_lib() - else: - module = tvm.IRModule({}) - module["main"] = model - with tvm.transform.PassContext(opt_level=1): - lib = relay.vm.compile(module, target=target, target_host="llvm") - ocl_lib = lib.module.imported_modules[0] - opencl_modules = list( - filter(lambda mod: mod.type_key == "opencl", ocl_lib.imported_modules) - ) - assembly = opencl_modules[0].get_source() - with tvm.target.Target(target): - limit = tvm.target.Target.current().max_function_args - max_num = get_maximum_kernel_args(assembly) - assert max_num <= limit - - dev = tvm.cl() - if executor_type == "ge": - module = graph_executor.GraphModule(lib["default"](dev)) - module.set_input(**input_dict) - module.run() - tvm_out = module.get_output(0) - else: - vm = VirtualMachine(lib, dev, "naive") - data = {} - for k, v in input_dict.items(): - data[k] = tvm.nd.array(v, dev) - vm.set_input("main", **data) - vm.invoke_stateful("main") - tvm_out = vm.get_outputs()[0] - - np_result = ref_impl(list(input_dict.values())) - np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-2, atol=1e-2) - if executor_type == "ge" and shape_type == "dynamic": pytest.skip() - validate() + _validate_opencl_executors(executor_type, _get_model, ref_impl) @tvm.testing.requires_gpu @@ -315,10 +315,11 @@ def validate(): def test_opencl_fuse_max_args(executor_type, shape_type): if shape_type == "dynamic": shape = (tvm.tir.Any(), 20) + ops_num = 80 else: shape = (1, 20) + ops_num = 300 shape_np = (1, 20) - ops_num = 300 dtype = "float32" def _base_func(name): @@ -347,65 +348,81 @@ def ref_impl(inputs): w = w + np.exp(inputs[i + 1] + 1) return w - def get_maximum_kernel_args(source): - def get_kernel_args(source): - import re - - p = re.compile(r"__kernel void .+\((.*)\)") - args = p.findall(source) - return args - - args = get_kernel_args(source) - max_args = len(args[0].split(",")) - for arg_line in args: - max_args = max(max_args, len(arg_line.split(","))) - return max_args - - def validate(): - from tvm.contrib import graph_executor - from tvm.runtime.vm import VirtualMachine - - input_dict, model = _get_model() - if executor_type == "ge": - with tvm.transform.PassContext(opt_level=3): - lib = relay.build(model, target_host="llvm", target=target) - ocl_lib = lib.get_lib() - else: - module = tvm.IRModule({}) - module["main"] = model - with tvm.transform.PassContext(opt_level=1): - lib = relay.vm.compile(module, target=target, target_host="llvm") - ocl_lib = lib.module.imported_modules[0] - opencl_modules = list( - filter(lambda mod: mod.type_key == "opencl", ocl_lib.imported_modules) - ) - assembly = opencl_modules[0].get_source() - with tvm.target.Target(target): - limit = tvm.target.Target.current().max_function_args - max_num = get_maximum_kernel_args(assembly) - assert max_num <= limit - - dev = tvm.cl() - if executor_type == "ge": - module = graph_executor.GraphModule(lib["default"](dev)) - module.set_input(**input_dict) - module.run() - tvm_out = module.get_output(0) - else: - vm = VirtualMachine(lib, dev, "naive") - data = {} - for k, v in input_dict.items(): - data[k] = tvm.nd.array(v, dev) - vm.set_input("main", **data) - vm.invoke_stateful("main") - tvm_out = vm.get_outputs()[0] + if executor_type == "ge" and shape_type == "dynamic": + pytest.skip() + _validate_opencl_executors(executor_type, _get_model, ref_impl) + + +@tvm.testing.requires_gpu +@tvm.testing.requires_opencl +def test_fuse_concat_max_num_args(executor_type, shape_type): + """ + In this test before concat we have an operation with 3 inputs. In the + SplitArgs we cannot calculate these inputs as inputs to concat, because + they will be added to the concat after fusing operation. So FuseOps pass + should handle this case and stop fusing before concat. + + The example: + x y z x y z + \ | / \ | / + \ | / \ | / + where ... where + | | + exp exp + \ / + \ / + \-----> concat <-----/ + """ + if shape_type == "dynamic": + shape = (tvm.tir.Any(), 20) + ops_num = 80 + else: + shape = (10, 20) + ops_num = 300 + shape_np = (10, 20) + dtype = "float32" + axis = 1 - np_result = ref_impl(list(input_dict.values())) - np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-2, atol=1e-2) + def _base_func(name): + x = relay.var(name, shape=shape) + y = relay.var(f"y{name}", shape=shape) + z = relay.var(f"z{name}", shape=shape) + cond = relay.less(x, relay.const(1, "float32")) + l = relay.add(y, relay.const(1, "float32")) + r = relay.add(z, relay.const(5, "float32")) + w = relay.where(cond, l, r) + w = relay.exp(w) + return [x, y, z], w + + def _get_model(): + inp = [] + out = [] + inputs_np = {} + for i in range(ops_num): + inputs, w = _base_func(f"x{i}") + inputs_np[f"x{i}"] = np.random.uniform(size=shape_np).astype(dtype) + inputs_np[f"yx{i}"] = np.random.uniform(size=shape_np).astype(dtype) + inputs_np[f"zx{i}"] = np.random.uniform(size=shape_np).astype(dtype) + inp.extend(inputs) + out.append(w) + t = relay.Tuple(out) + w = relay.op.concatenate(t, axis) + return inputs_np, relay.Function(inp, w) + + def ref_impl(inputs): + res = [] + for i in range(0, len(inputs), 3): + x = inputs[i] + y = inputs[i + 1] + z = inputs[i + 2] + comp = np.where(x < 1, y + 1, z + 5) + comp = np.exp(comp) + res.append(comp) + return np.concatenate(tuple(res), axis=axis) if executor_type == "ge" and shape_type == "dynamic": pytest.skip() - validate() + _validate_opencl_executors(executor_type, _get_model, ref_impl) if __name__ == "__main__":