Skip to content

Commit

Permalink
[Relay] Introduce arguments limit to FuseOps pass
Browse files Browse the repository at this point in the history
In PR apache#8313 a parameter `max_function_args` was introduced. It leads to
limit number of function argument and in case when this value is
exceeded then concatenation layer is split to a several concat
operations.

I faced a problem on Adreno GPU that for kernel with big number of
arguments the enqueueNDRange was crashed without any errors. The problem
appeared because of the huge number of arguments. But in this case not
only concat layer was a root cause of the problem. Also after fusing
several operations the final functions had a big number of arguments.

As it was discussed in apache#8313, adding a limitation on the number of
function arguments to the FuseOps pass might be a good improvement. In
this PR I introduced such mechanism for limitation number of function
arguments for FuseOps pass and add an arguments limit to OpenCL devices
at 128 parameters.
  • Loading branch information
echuraev committed Jun 21, 2023
1 parent b37ad17 commit 68470b4
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 51 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ TVM_DLL Pass FoldConstant(bool fold_qnn = false);
*
* \return The pass.
*/
TVM_DLL Pass SplitArgs(int max_function_args);
TVM_DLL Pass SplitArgs(size_t max_function_args);

/*!
* \brief Fuse operations into expr into separate functions.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def max_shared_memory_per_block(self):

@property
def max_function_args(self):
return int(self.attrs.get("max_function_args", -1))
return int(self.attrs.get("max_function_args", 0))

@property
def vtcm_capacity(self):
Expand Down
30 changes: 30 additions & 0 deletions src/relay/analysis/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,29 @@ size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node*
return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent);
}

size_t GraphPartitioner::CountArgs_(const tvm::Object* child, const tvm::Object* till_node) {
if (child == till_node) return 1;
auto args_num = 0;
if (auto call_node = GetRef<ObjectRef>(child).as<CallNode>()) {
for (auto& it : call_node->args) {
// Check if call_node or constant
if (it.as<CallNode>() || it.as<ConstantNode>()) {
args_num += CountArgs_(it.get(), till_node);
} else {
args_num++;
}
}
}
return args_num;
}

size_t GraphPartitioner::CountFusedArgs(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* till_node) {
const tvm::Object* till_node_ref = (till_node != nullptr) ? till_node->ref : nullptr;
auto args = CountArgs_(child->ref, till_node_ref);
return args;
}

void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
groups_.resize(graph.post_dfs_order.size());
for (size_t nid = 0; nid < groups_.size(); ++nid) {
Expand All @@ -238,6 +261,7 @@ void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
const DominatorTree& post_dom_tree, //
int phase) {
IndexedForwardGraph::Node* prev_node = nullptr;
for (size_t nid = 0; nid < groups_.size(); ++nid) {
// the group of current node has been specified already.
auto* graph_node = graph.post_dfs_order[nid];
Expand All @@ -254,6 +278,12 @@ void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, //
// refuse the fusion if too many ops are going to be fused together
if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)
continue;
// refuse the fusion if too many arguments are going to be in fused function
// max_function_args_ - 1 because one argument contains buffer with output
if (CountFusedArgs(graph_node, prev_node) >= max_function_args_ - 1) {
prev_node = graph_node;
continue;
}

if (phase == 2) {
// Fuse injective ops into intermediate tuples, if any
Expand Down
13 changes: 10 additions & 3 deletions src/relay/analysis/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class IndexedForwardGraph {
std::vector<Node*> post_dfs_order;

/*! \brief Dump the graph into string. */
void DebugDump() {
void DebugDump() const {
std::ostringstream os;
for (size_t i = 0; i < post_dfs_order.size(); ++i) {
Node* node = post_dfs_order[i];
Expand Down Expand Up @@ -162,8 +162,8 @@ class DominatorTree {
*/
class GraphPartitioner {
public:
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth)
: arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {}
explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth, size_t max_function_args)
: arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth), max_function_args_(max_function_args) {}
/*!
* \brief Group as a union find data structure.
*/
Expand Down Expand Up @@ -205,6 +205,8 @@ class GraphPartitioner {
int opt_level_;
/*! \brief The maximum number of operations in one fused function */
size_t max_fuse_depth_;
/*! \brief The maximum number of arguments in one fused function */
size_t max_function_args_;
/*! \brief The internal groups. */
std::vector<Group*> groups_;
/*! \brief internal field used for deduplication */
Expand Down Expand Up @@ -247,6 +249,7 @@ class GraphPartitioner {
void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);

size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink);
size_t CountArgs_(const tvm::Object* child, const tvm::Object* till_node);

// Count the number of nodes in a fused subgraph if child is additionally fused.
// dom_parent is already known to be a part of the subgraph.
Expand All @@ -256,6 +259,10 @@ class GraphPartitioner {
// is important for correct calculation.
size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child,
IndexedForwardGraph::Node* dom_parent);
// Count the number of arguments in a fused subgraph if child is additionally fused.
// Calculation goes from child node till the till_node. If till_node is a
// nullptr then calculate arguments till the beginning of the graph.
size_t CountFusedArgs(IndexedForwardGraph::Node* child, IndexedForwardGraph::Node* till_node = nullptr);

// Initialize the groups.
void InitGroups(const IndexedForwardGraph& graph);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if (config_->optional_homogeneous_target.defined()) {
// This pass currently only supports the homogeneous case.
pass_seqs.push_back(transform::SplitArgs(
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", -1)
config_->optional_homogeneous_target->GetAttr<Integer>("max_function_args", 0)
.value()
.IntValue()));
}
Expand Down
23 changes: 16 additions & 7 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,10 @@ class IndexedForwardGraphCreator : private ExprVisitor {

class FuseMutator : private MixedModeMutator {
public:
FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params)
FuseMutator(int fuse_opt_level, size_t max_fuse_depth, size_t max_function_args, bool link_params)
: fuse_opt_level_(fuse_opt_level),
max_fuse_depth_(max_fuse_depth),
max_function_args_(max_function_args),
link_params_(link_params) {}

// Run the transform
Expand All @@ -334,7 +335,8 @@ class FuseMutator : private MixedModeMutator {
Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
// setup the group map.
auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph);
auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
.Partition(graph);
for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
Expand All @@ -347,6 +349,7 @@ class FuseMutator : private MixedModeMutator {
private:
int fuse_opt_level_;
size_t max_fuse_depth_;
size_t max_function_args_;
bool link_params_;

using MixedModeMutator::VisitExpr_;
Expand Down Expand Up @@ -548,9 +551,10 @@ class FuseMutator : private MixedModeMutator {
}
};

Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, bool link_params,
const IRModule& module) {
return FuseMutator(fuse_opt_level, max_fuse_depth, link_params).Transform(expr);
Expr FuseOps(const Expr& expr, int fuse_opt_level, size_t max_fuse_depth, size_t max_function_args,
bool link_params, const IRModule& module) {
return FuseMutator(fuse_opt_level, max_fuse_depth, max_function_args, link_params)
.Transform(expr);
}

namespace transform {
Expand All @@ -567,8 +571,13 @@ Pass FuseOps(int fuse_opt_level) {
link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value();
int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
return Downcast<Function>(
FuseOps(f, opt_level, max_fuse_depth.value().IntValue(), link_params, m));
auto target = Target::Current();
size_t max_function_args =
(target.defined())
? target->GetAttr<Integer>("max_function_args", Integer(0)).value().IntValue()
: 0;
return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value().IntValue(),
max_function_args, link_params, m));
};
return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"});
}
Expand Down
17 changes: 9 additions & 8 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,21 @@ namespace relay {

class ArgumentSplitter : public ExprRewriter {
public:
explicit ArgumentSplitter(int max_function_args)
explicit ArgumentSplitter(size_t max_function_args)
: max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (max_function_args_ < 0) return post;
if (max_function_args_ == 0) return post;
if (call->op == concat_op_) {
auto tuple_node = call->args[0].as<TupleNode>();
const auto param = call->attrs.as<ConcatenateAttrs>();
int outputsNum = 1;
size_t outputsNum = 1;
if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
outputsNum = tuple_type->fields.size();
}
const int limit = max_function_args_ - outputsNum;
int argsNum = tuple_node->fields.size();
CHECK_GT(max_function_args_, outputsNum);
const size_t limit = max_function_args_ - outputsNum;
size_t argsNum = tuple_node->fields.size();
if (argsNum < limit) return post;
int splitNum = argsNum / limit;
splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
Expand All @@ -71,18 +72,18 @@ class ArgumentSplitter : public ExprRewriter {
}

private:
const int max_function_args_;
const size_t max_function_args_;
const Op& concat_op_;
};

Expr SplitArgs(const Expr& expr, int max_function_args) {
Expr SplitArgs(const Expr& expr, size_t max_function_args) {
auto rewriter = ArgumentSplitter(max_function_args);
return PostOrderRewrite(expr, &rewriter);
}

namespace transform {

Pass SplitArgs(int max_function_args) {
Pass SplitArgs(size_t max_function_args) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
auto r = Downcast<Function>(SplitArgs(f, max_function_args));
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {

// Buffer arguments
size_t num_buffer = 0;
int limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
if (static_cast<int>(f->params.size()) > limit) {
size_t limit = target_->GetAttr<Integer>("max_function_args").value().IntValue();
if (f->params.size() > limit) {
LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
"buffers in the kernel";
}
Expand Down
5 changes: 5 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,15 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.set_default_keys({"rocm", "gpu"})
.set_target_parser(UpdateROCmAttrs);

// Faced that Qualcomm OpenCL runtime was crashed without any error message in
// case when the number of kernel arguments was pretty big. OpenCL doesn't
// specify any limitations on the number of kernel arguments. max_function_args
// equals to 128 looks like a reasonable number of kernel arguments.
TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
.add_attr_option<Integer>("max_function_args", Integer(128))
.set_default_keys({"opencl", "gpu"});

// The metal has some limitations on the number of input parameters. This is why attribute
Expand Down
87 changes: 87 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,12 @@ def expected(n, max_fused_ops):

assert tvm.ir.structural_equal(zz, after)

with tvm.target.Target("opencl"):
with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
cl_zz = run_opt_pass(z, transform.FuseOps())

assert tvm.ir.structural_equal(cl_zz, after)


link_params = tvm.testing.parameter(False, True)

Expand Down Expand Up @@ -828,5 +834,86 @@ def expected():
tvm.testing.assert_allclose(result, ref, rtol=1e-4, atol=1e-4)


target_name = tvm.testing.parameter("opencl", "metal", "cuda")


def test_fuse_max_num_args(target_name):
def _base_func(name):
x = relay.var(name, shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.exp(y)
w = relay.squeeze(z)
return x, w

def before(n):
inp = []
out = []
for i in range(n):
x, w = _base_func(f"x{i}")
inp.append(x)
out.append(w)
w = out[0]
for i in range(len(out) - 1):
w = relay.add(w, out[i + 1])
return relay.Function(inp, w)

def after(n):
def create_base_funcs_sum(args_num, prev=None):
added_args = 0
inputs = []
input_vars = []
res = None
for i in range(args_num):
added_args += 1
inp, out = _base_func(f"p{i}")
input_vars.append(relay.var(f"x{i}", shape=(10, 20)))
inputs.append(inp)
if res is None:
res = out
elif i == 1 and prev is not None:
res = relay.add(inp, res)
added_args -= 1
else:
res = relay.add(res, out)
f = relay.Function(inputs, res)
f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
return added_args, input_vars, f

def create_accum_func(args_limit):
out = None
inputs = []
added_args = 0
while added_args < n:
# When out is not none that means that one additional argument
# will be used for the result of previous fusing
args_num = n - added_args if out is None else n - added_args + 1
if args_limit > 0:
args_num = min(args_num, args_limit)
a_num, inp, func = create_base_funcs_sum(args_num, out)
added_args += a_num
inputs.append(inp[0])
if out is not None:
inp[1] = out
else:
inputs.append(inp[1])
inputs.extend(inp[2:])
out = relay.Call(func, inp)
return inputs, out

args_limit = tvm.target.Target.current().max_function_args - 1 # one buffer with output
args_limit = max(args_limit, 0)
inp, w = create_accum_func(args_limit)
return relay.Function(inp, w)

ops_num = 300
max_fused_ops = ops_num * 5
with tvm.target.Target(target_name):
with tvm.transform.PassContext(config={"relay.FuseOps.max_depth": max_fused_ops}):
fused = run_opt_pass(before(ops_num), transform.FuseOps())

expected = run_opt_pass(after(ops_num), transform.InferType())
assert tvm.ir.structural_equal(fused, expected)


if __name__ == "__main__":
tvm.testing.main()
Loading

0 comments on commit 68470b4

Please sign in to comment.