Skip to content

Commit

Permalink
[Relay] Introduce arguments limit to FuseOps pass
Browse files Browse the repository at this point in the history
In PR apache#8313 a parameter `max_function_args` was introduced. It limits
the number of function arguments, and when this value is exceeded, the
concatenation layer is split into several concat operations.

I faced a problem on Adreno GPU where, for kernels with a big number of
arguments, enqueueNDRange crashed without reporting any errors. The
problem appeared because of the huge number of arguments. But in this
case the concat layer was not the only root cause of the problem: after
fusing several operations, the final functions also had a big number of
arguments.

As was discussed in apache#8313, adding a limitation on the number of
function arguments to the FuseOps pass might be a good improvement. In
this PR I introduced such a mechanism limiting the number of function
arguments for the FuseOps pass, and added an argument limit of 128
parameters for OpenCL devices.

The idea of the current approach is to calculate the number of arguments
for each node in the fusing algorithm, and when the number of function
arguments exceeds the limit specified by `max_function_args`, fusing is
stopped. When a node has several inputs and the number of arguments
wasn't computed for some of them, we postpone fusing for this node and
try to fuse it later, once the number of arguments has been computed for
all inputs. This approach with postponed fusing helps to avoid
additional computations during compilation.

Additionally, the case of dynamic shapes should be handled. In the case
of a dynamic shape, the function arguments also include the sizes of the
dynamic dimensions and the strides. The number of strides can be
computed from the number of tensor dimensions (the number of strides
equals the rank of the tensor). The number of additional parameters with
the sizes of dynamic dimensions can be calculated by counting the
dynamic dimensions.
  • Loading branch information
echuraev committed Jul 18, 2023
1 parent b8b4886 commit d4b8036
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 235 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
* \brief Split function with huge number of arguments to smaller pieces.
*
* \param max_function_args Maximum number of function arguments. If it is 0 then SplitArgs won't
* split funciton.
* split function.
*
* \return The pass.
*/
Expand Down
9 changes: 6 additions & 3 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from . import _make
from .dyn import _make as _dyn_make
from ..expr import Tuple, Expr, Constant
from ..expr import Tuple, Expr, Constant, Call
from . import op as reg


Expand Down Expand Up @@ -1141,12 +1141,15 @@ def concatenate(data, axis):
result: relay.Expr
The concatenated tensor.
"""
data = list(data)
if not isinstance(data, Call):
data = list(data)
if not data:
raise ValueError("relay.concatenate requires data to be non-empty.")
if not isinstance(data, Call):
data = Tuple(data)
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.concatenate(Tuple(data), axis)
return _make.concatenate(data, axis)


def einsum(data, equation):
Expand Down
160 changes: 102 additions & 58 deletions src/relay/analysis/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
if (child == parent) return;
// update the number of nodes of the parent group
parent->num_nodes += child->num_nodes;
parent->args_num += child->args_num;
child->parent = parent;
// update anchor ref and pattern
if (child->anchor_ref != nullptr) {
Expand All @@ -193,6 +194,10 @@ void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwar
}

void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
if (postpone_node_ != nullptr) {
postponed_fusing_map_.insert({postpone_node_, src});
return;
}
Group* target = groups_[sink->index];
visited_.clear();
ICHECK(src != sink);
Expand Down Expand Up @@ -220,6 +225,45 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node*
return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}

// Computes the total number of arguments the fused function containing `src`
// would take: the node's own group count (args_num, initialized in
// InitGroups) plus the argument counts contributed by each distinct input
// group feeding this node.
//
// When `update_postpone` is true and some input group's count is not yet
// available, `src` is recorded in postpone_node_ so that fusing can be
// retried later (deferred fusing); when false, the missing count is computed
// recursively on the spot.
size_t GraphPartitioner::CountArgs_(IndexedForwardGraph::Node* src,
                                    const IndexedForwardGraph& graph, bool update_postpone) {
  // Groups already counted: a group feeding several inputs of this node must
  // contribute its arguments only once.
  std::unordered_set<Group*> visited_groups;
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  auto sum = gnode->args_num;
  visited_groups.insert(gnode->FindRoot());
  // Returns the number of arguments contributed by one input expression.
  auto calcArgs = [this, src, &graph, &visited_groups,
                   update_postpone](const relay::Expr& arg) -> size_t {
    // Plain variables are already accounted for in the group's own args_num
    // (see InitGroups' args_counter), so they contribute nothing extra here.
    if (arg.as<VarNode>()) return 0;
    auto* node = graph.node_map.at(arg.get());
    Group* prev_group = groups_[node->index]->FindRoot();
    if (visited_groups.count(prev_group) == 0) {
      visited_groups.insert(prev_group);
      // NOTE(review): args_num == 0 doubles as "not computed yet"; a group
      // with genuinely zero arguments is indistinguishable from an
      // unprocessed one at this point — confirm this cannot miscount.
      if (prev_group->args_num > 0) {
        // Get number of arguments from group
        return prev_group->args_num;
      } else if (update_postpone) {
        // Update pointer to node which should be postponed for deferred fusing
        postpone_node_ = src;
      } else {
        // Calculate number of arguments for the node which wasn't processed before
        return CountArgs_(node, graph, update_postpone);
      }
    }
    return 0;
  };
  // Sum the contributions of every input of the node: call arguments for a
  // CallNode, tuple fields for a TupleNode. Other node kinds contribute only
  // their group's own args_num.
  if (auto call_node = GetRef<ObjectRef>(src->ref).as<CallNode>()) {
    for (auto& it : call_node->args) {
      sum += calcArgs(it);
    }
  } else if (auto tuple_node = GetRef<ObjectRef>(src->ref).as<TupleNode>()) {
    for (auto& it : tuple_node->fields) {
      sum += calcArgs(it);
    }
  }
  return sum;
}

size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides) {
size_t any_dims = 0;
for (const auto& dim : ttype->shape) {
Expand All @@ -231,48 +275,6 @@ size_t GraphPartitioner::CountAdditionalArgs_(const TensorTypeNode* ttype, bool
return any_dims;
}

// Recursively counts the number of arguments required by the expression tree
// rooted at `child`, traversing until `till_node` is reached; per-object
// results are memoized in argsMap_. Var/TupleGetItem leaves count as one
// tensor argument plus any dynamic-shape extras (CountAdditionalArgs_).
size_t GraphPartitioner::CountArgs_(const tvm::Object* child, const tvm::Object* till_node) {
  if (child == till_node) {
    // Calculate number of output arguments
    if (auto call_node = GetRef<ObjectRef>(child).as<CallNode>()) {
      if (const auto* ttype = call_node->checked_type().as<TensorTypeNode>()) {
        // One output tensor plus extra parameters for dynamic dimensions.
        return CountAdditionalArgs_(ttype) + 1;
      }
    }
    return 1;
  }
  // Reuse a previously memoized count for this object.
  if (argsMap_.count(child)) {
    return argsMap_[child];
  }
  size_t args_num = 0;
  if (auto call_node = GetRef<ObjectRef>(child).as<CallNode>()) {
    for (auto& it : call_node->args) {
      if (it.as<CallNode>() || it.as<TupleNode>()) {
        // Composite input: recurse into its arguments.
        args_num += CountArgs_(it.get(), till_node);
      } else if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
        // Leaf input: one tensor argument plus dynamic-shape extras.
        args_num++;
        if (const auto* ttype = it->checked_type().as<TensorTypeNode>()) {
          args_num += CountAdditionalArgs_(ttype);
        }
      }
    }
  } else if (GetRef<ObjectRef>(child).as<VarNode>() ||
             GetRef<ObjectRef>(child).as<TupleGetItemNode>()) {
    // The node itself is a leaf argument.
    args_num++;
    if (const auto* ttype =
            GetRef<ObjectRef>(child).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
      args_num += CountAdditionalArgs_(ttype);
    }
  } else if (auto tuple_node = GetRef<ObjectRef>(child).as<TupleNode>()) {
    for (const auto& it : tuple_node->fields) {
      // Each tuple field counts as an argument; recurse for nested inputs.
      args_num++;
      args_num += CountArgs_(it.get(), till_node);
    }
  }
  argsMap_[child] = args_num;
  return args_num;
}

size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child) {
auto* outputs_list = child->outputs.head;
size_t output_args = 0;
Expand All @@ -288,21 +290,47 @@ size_t GraphPartitioner::CountArgsLimit_(const IndexedForwardGraph::Node* child)
return (max_function_args_ > output_args) ? max_function_args_ - output_args : 0;
}

size_t GraphPartitioner::CountFusedArgs(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* till_node) {
const tvm::Object* till_node_ref = (till_node != nullptr) ? till_node->ref : nullptr;
size_t GraphPartitioner::CountFusedArgs(const IndexedForwardGraph& graph,
IndexedForwardGraph::Node* child) {
size_t args_num = 0;
auto* outputs_list = child->outputs.head;
size_t res = 1;
while (outputs_list != nullptr) {
size_t output_args = 0;
output_args += CountArgs_(outputs_list->value.node->ref, till_node_ref);
res = std::max(res, output_args);
args_num = std::max(args_num, CountArgs_(outputs_list->value.node, graph));
outputs_list = outputs_list->next;
}
return res;
return args_num;
}

void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
auto args_counter = [this](const tvm::Object* obj) {
size_t args_num = 0;
if (auto call_node = GetRef<ObjectRef>(obj).as<CallNode>()) {
for (auto& it : call_node->args) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (auto tuple_node = GetRef<ObjectRef>(obj).as<TupleNode>()) {
for (auto& it : tuple_node->fields) {
if (it.as<VarNode>() || it.as<TupleGetItemNode>()) {
args_num++;
if (const auto* ttype = it.as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
}
} else if (GetRef<ObjectRef>(obj).as<VarNode>()) {
args_num++;
if (const auto* ttype =
GetRef<ObjectRef>(obj).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
args_num += CountAdditionalArgs_(ttype);
}
}
return args_num;
};
groups_.resize(graph.post_dfs_order.size());
for (size_t nid = 0; nid < groups_.size(); ++nid) {
const auto* graph_node = graph.post_dfs_order[nid];
Expand All @@ -313,21 +341,33 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
if (group_node->pattern == relay::kOutEWiseFusable) {
group_node->anchor_ref = graph_node->ref;
}
group_node->args_num = args_counter(graph_node->ref);
groups_[nid] = group_node;
}
}

void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
const DominatorTree& post_dom_tree, //
int phase) {
IndexedForwardGraph::Node* prev_node = nullptr;
argsMap_.clear();
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
auto* dom_node = post_dom_tree.nodes[nid];
Group* group_node = groups_[nid];
ICHECK(group_node != nullptr);
postpone_node_ = nullptr;
if (postponed_fusing_map_.count(graph_node)) {
auto range = postponed_fusing_map_.equal_range(graph_node);
for (auto it = range.first; it != range.second; ++it) {
if (CountArgs_(graph_node, graph, false) <= CountArgsLimit_(graph_node)) {
auto* src = it->second;
auto* snode = post_dom_tree.nodes[src->index]->parent->gnode;
if (groups_[snode->index]->anchor_ref != nullptr) continue;
CommitFuse(src, snode);
}
}
postponed_fusing_map_.erase(graph_node);
}
// no actions for opaque nodes
if (group_node->pattern == kOpaque) continue;
// no actions needed if the current node have no dominator
Expand All @@ -339,12 +379,16 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
// refuse the fusion if too many arguments are going to be in fused function
auto limit = CountArgsLimit_(graph_node);
if (limit > 0) {
if (CountFusedArgs(graph_node, prev_node) > limit) {
argsMap_.clear();
prev_node = graph_node;
continue;
if (max_function_args_ > 0) {
auto limit = CountArgsLimit_(graph_node);
if (limit > 0) {
auto args = CountFusedArgs(graph, graph_node);
//std::cout << "args: " << args << ", limit: " << limit << std::endl;
if (args > limit) {
//std::cout << " >>> args: " << args << ", limit: " << limit << std::endl;
//Dump(graph_node->ref);
continue;
}
}
}

Expand Down
34 changes: 29 additions & 5 deletions src/relay/analysis/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ class GraphPartitioner {
* \brief The number of nodes belonging to this group
*/
uint32_t num_nodes{1};
/*!
* \brief The number of function arguments belonging to this group
*/
size_t args_num{0};

/*! \brief Optional attributes to annotate the grouped function. */
runtime::Map<runtime::String, ObjectRef> attrs;
Expand Down Expand Up @@ -215,8 +219,15 @@ class GraphPartitioner {
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
std::unordered_set<IndexedForwardGraph::Node*> visited_;
/*! \brief internal field used for hashing arguments number for a node */
std::unordered_map<const tvm::Object*, size_t> argsMap_;
/*! \brief The map with nodes which were postponed for fusing. */
std::unordered_multimap<const IndexedForwardGraph::Node*, IndexedForwardGraph::Node*>
postponed_fusing_map_;
/*!
* \brief Fusing of this node should be postponed till all child nodes will be evaluated.
* It is used to calculate number of arguments which will be passed to this node in
* generated function.
*/
const IndexedForwardGraph::Node* postpone_node_{nullptr};
// Internal implementation of CheckPath
template <typename F>
bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond);
Expand Down Expand Up @@ -255,8 +266,22 @@ class GraphPartitioner {
void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);

size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
// Count the number of additional arguments. In case of dynamic shape,
// generated function takes several additional arguments, such as size of
// dynamic dimension and strides.
// This function calculates number of such additional arguments.
size_t CountAdditionalArgs_(const TensorTypeNode* ttype, bool with_strides = true);
size_t CountArgs_(const tvm::Object* child, const tvm::Object* till_node);
// Calculate the number of arguments for the node.
size_t CountArgs_(IndexedForwardGraph::Node* src, const IndexedForwardGraph& graph,
bool update_postpone = true);
// Count actual limit of arguments for a generated function.
// max_function_args_ specifies the number of maximum function arguments. But
// usually, output tensors also passed to the function as arguments.
// Additionally, in case of dynamic shape, it is necessary to take into
// account the number of parameters which specifies the size of dynamic
// dimension.
// This function computes limit of arguments by the following formula:
// limit = max_function_args_ - output_args_count
size_t CountArgsLimit_(const IndexedForwardGraph::Node* child);

// Count the number of nodes in a fused subgraph if child is additionally fused.
Expand All @@ -270,8 +295,7 @@ class GraphPartitioner {
// Count the number of arguments in a fused subgraph if child is additionally fused.
// Calculation goes from child node till the till_node. If till_node is a
// nullptr then calculate arguments till the beginning of the graph.
size_t CountFusedArgs(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* till_node = nullptr);
size_t CountFusedArgs(const IndexedForwardGraph& graph, IndexedForwardGraph::Node* child);

// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph);
Expand Down
12 changes: 6 additions & 6 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class ArgumentSplitter : public ExprRewriter {

Expr ConcatSplitter(const TupleNode* tuple_node, const tvm::Array<relay::Expr>& args, int axis,
size_t limit) {
relay::Expr lastExpr;
tvm::Array<relay::Expr> new_args;
size_t added_args = 0;
for (const auto& it : args) {
Expand All @@ -47,18 +46,18 @@ class ArgumentSplitter : public ExprRewriter {
}
if (added_args + curr_args > limit) {
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), new_args);
Expr body = MakeConcatenate(new_tuple, axis);
lastExpr = StopFusion(body);
Expr stop = StopFusion(new_tuple);
Expr lastExpr = MakeConcatenate(stop, axis);
new_args.clear();
new_args.push_back(lastExpr);
added_args = 1;
added_args = curr_args;
}
added_args += curr_args;
new_args.push_back(it);
}
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), new_args);
Expr body = MakeConcatenate(new_tuple, axis);
lastExpr = StopFusion(body);
Expr stop = StopFusion(new_tuple);
Expr lastExpr = MakeConcatenate(stop, axis);
return lastExpr;
}

Expand All @@ -83,6 +82,7 @@ class ArgumentSplitter : public ExprRewriter {
if (max_function_args_ == 0) return post;
if (call->op == concat_op_) {
auto tuple_node = call->args[0].as<TupleNode>();
if (tuple_node == nullptr) return post;
const auto param = call->attrs.as<ConcatenateAttrs>();
size_t outputsNum = 1;
if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
Expand Down
Loading

0 comments on commit d4b8036

Please sign in to comment.