Skip to content

Commit

Permalink
[Metal] Add pass for splitting kernel with huge number of args (apach…
Browse files Browse the repository at this point in the history
…e#8313)

* [Metal] Add pass for splitting kernel with huge number of args

Metal has some limitations on the number of kernel input parameters. More
information can be found here:
https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc

This commit adds a new pass that splits functions with a large number of
arguments into smaller parts. The parameter `max_function_args` specifies
the maximum number of kernel arguments for a specific target; a kernel is
split when its number of arguments exceeds the value of
`max_function_args`. Currently this pass works only for the concat layer.

* Add getting number of output parameters

* Fix CI and apply comments
  • Loading branch information
echuraev authored and ylc committed Sep 29, 2021
1 parent 9b42371 commit fa4e6a9
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 3 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ TVM_DLL Pass LazyGradientInit();
*/
TVM_DLL Pass FoldConstant();

/*!
 * \brief Split functions with a large number of arguments into smaller pieces.
 *
 * \param max_function_args Maximum number of kernel arguments; a negative
 *        value disables the pass.
 * \return The pass.
 */
TVM_DLL Pass SplitArgs(int max_function_args);

/*!
 * \brief Fuse operations into expr into separate functions.
*
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,3 +1228,14 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
if missing_op_mode < 0 or missing_op_mode > 2:
raise ValueError("Missing op mode is either 0, 1, or 2")
return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode)


def SplitArgs(max_function_args):
    """Split function with huge number of arguments to smaller pieces.

    Parameters
    ----------
    max_function_args : int
        Maximum number of function arguments. A negative value disables
        the pass (no splitting is performed).

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass for splitting functions with a large number
        of arguments.
    """
    return _ffi_api.SplitArgs(max_function_args)
4 changes: 4 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def max_num_threads(self):
def thread_warp_size(self):
return int(self.attrs["thread_warp_size"])

@property
def max_function_args(self):
    """Maximum number of kernel arguments supported by the target,
    or -1 when the target declares no limit."""
    limit = self.attrs.get("max_function_args", -1)
    return int(limit)

@property
def device_name(self):
return str(self.attrs.get("device", ""))
Expand Down
6 changes: 6 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,12 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

if (targets.size() == 1) {
const auto& target = (*targets.begin()).second;
pass_seqs.push_back(
transform::SplitArgs(target->GetAttr<Integer>("max_function_args", -1).value()));
}

// Create a sequential pass and perform optimizations.
transform::Pass seq = transform::Sequential(pass_seqs);
if (targets.size() == 1) {
Expand Down
95 changes: 95 additions & 0 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file split_args.cc
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

/*!
 * \brief Rewrites concatenate calls whose number of arguments exceeds the
 *        target's limit into a concatenate of smaller partial concatenates.
 */
class ArgumentSplitter : public ExprRewriter {
 public:
  explicit ArgumentSplitter(int max_function_args)
      : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    // A negative limit disables the pass entirely.
    if (max_function_args_ < 0) return post;
    if (call->op != concat_op_) return post;

    const auto* tuple = call->args[0].as<TupleNode>();
    // The argument of concatenate is expected to be a literal tuple; if it is
    // not (e.g. a variable of tuple type), we cannot split it here.
    if (tuple == nullptr) return post;
    const auto* param = call->attrs.as<ConcatenateAttrs>();
    ICHECK(param != nullptr);

    // Reserve argument slots for the kernel outputs.
    int outputsNum = 1;
    if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
      outputsNum = static_cast<int>(tuple_type->fields.size());
    }
    const int limit = max_function_args_ - outputsNum;
    // If the outputs alone exhaust the limit there is no room for inputs and
    // nothing useful can be done (also guards the division below).
    if (limit <= 0) return post;

    const int argsNum = static_cast<int>(tuple->fields.size());
    // Nothing to do when the arguments already fit into a single kernel.
    if (argsNum <= limit) return post;
    // ceil(argsNum / limit) partial concatenations are required.
    const int splitNum = (argsNum + limit - 1) / limit;

    std::vector<Expr> splitted(splitNum);
    for (int i = 0; i < splitNum; ++i) {
      const int startIdx = i * limit;
      const int argsCount = std::min(limit, argsNum - startIdx);
      tvm::Array<Expr> args;
      for (int j = 0; j < argsCount; ++j) {
        args.push_back(tuple->fields[j + startIdx]);
      }
      Tuple group(args);
      Expr body = MakeConcatenate(group, param->axis);
      // Stop fusion so each partial concat is compiled as a separate kernel.
      splitted[i] = StopFusion(body);
    }
    // Concatenate the partial results along the same axis.
    tvm::Array<Expr> tupleArgs(splitted);
    Tuple combined(tupleArgs);
    return MakeConcatenate(combined, param->axis);
  }

 private:
  const int max_function_args_;  // Max kernel arguments (inputs + outputs); < 0 disables.
  const Op& concat_op_;          // Cached "concatenate" operator for fast comparison.
};

/*!
 * \brief Rewrite \p expr so no concatenate call exceeds the argument limit.
 * \param expr The expression to transform.
 * \param max_function_args Maximum number of kernel arguments; negative
 *        disables the transformation.
 * \return The rewritten expression.
 */
Expr SplitArgs(const Expr& expr, int max_function_args) {
  ArgumentSplitter rewriter(max_function_args);
  return PostOrderRewrite(expr, &rewriter);
}

namespace transform {

Pass SplitArgs(int max_function_args) {
  // Wrap the expression-level rewrite into a function-level pass.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [max_function_args](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(SplitArgs(f, max_function_args));
      };
  // Depends on "InferType": the rewrite queries checked_type() for outputs.
  return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs);

} // namespace transform

} // namespace relay
} // namespace tvm
9 changes: 7 additions & 2 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void CodeGenMetal::InitFuncState(const PrimFunc& f) {
}
}

CodeGenMetal::CodeGenMetal() {
CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
decl_stream << "#include <metal_stdlib>\n";
decl_stream << "using namespace metal;\n\n";
decl_stream << "union __TVMArgUnion {\n"
Expand All @@ -67,6 +67,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {

// Buffer arguments
size_t num_buffer = 0;
int limit = target_->GetAttr<Integer>("max_function_args").value();
if (static_cast<int>(f->params.size()) > limit) {
LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of "
"buffers in the kernel";
}
for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
Var v = f->params[i];
if (!v.dtype().is_handle()) break;
Expand Down Expand Up @@ -332,7 +337,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
code << "// Function: " << kv.first->name_hint << std::endl;
CodeGenMetal cg;
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
Expand Down
3 changes: 2 additions & 1 deletion src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace codegen {

class CodeGenMetal final : public CodeGenC {
public:
CodeGenMetal();
explicit CodeGenMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
void AddFunction(const PrimFunc& f); // NOLINT(*)
Expand All @@ -58,6 +58,7 @@ class CodeGenMetal final : public CodeGenC {

private:
int thread_index_bits_{32};
Target target_;
};
} // namespace codegen
} // namespace tvm
Expand Down
5 changes: 5 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,15 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<Integer>("thread_warp_size", Integer(1))
.set_default_keys({"opencl", "gpu"});

// Metal has some limitations on the number of input parameters. This is why the attribute
// `max_function_args` was introduced. It specifies the maximum number of kernel arguments. More
// information about this limitation can be found here:
// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
TVM_REGISTER_TARGET_KIND("metal", kDLMetal)
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(16))
.add_attr_option<Integer>("max_function_args", Integer(31))
.set_default_keys({"metal", "gpu"});

TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
Expand Down
96 changes: 96 additions & 0 deletions tests/python/relay/test_pass_split_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload


def run_opt_pass(expr, opt_pass):
    """Wrap ``expr`` in an IRModule, infer types, apply ``opt_pass`` and
    return the transformed main function (or its body for bare exprs)."""
    assert isinstance(opt_pass, tvm.transform.Pass)

    mod = relay.transform.InferType()(tvm.IRModule.from_expr(expr))
    mod = opt_pass(mod)
    entry = mod["main"]
    if isinstance(expr, relay.Function):
        return entry
    return entry.body


def test_split_concat_metal():
    """A concatenate of 100 tensors must be split for the metal target."""
    shape = (1, 1, 1, 3)
    dtype = "float32"
    axis = 1
    inputs = [relay.var("p{}".format(i), shape=shape, dtype=dtype) for i in range(100)]

    def before():
        return relay.op.concatenate(relay.Tuple(inputs), axis)

    def expected():
        # One argument slot is reserved for the single output buffer.
        limit = tvm.target.Target("metal").max_function_args - 1
        groups = []
        for start in range(0, len(inputs), limit):
            chunk = relay.Tuple(inputs[start : start + limit])
            partial = relay.op.concatenate(chunk, axis)
            groups.append(relay.annotation.stop_fusion(partial))
        return relay.op.concatenate(relay.Tuple(groups), axis)

    res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("metal").max_function_args))
    exp = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(res, exp)


def test_split_concat_cuda():
    """cuda declares no argument limit (-1), so the graph stays unchanged."""
    shape = (1, 1, 1, 3)
    dtype = "float32"
    axis = 1
    inputs = [relay.var("p{}".format(i), shape=shape, dtype=dtype) for i in range(100)]

    def before():
        return relay.op.concatenate(relay.Tuple(inputs), axis)

    def expected():
        # The pass is a no-op for this target, so expected == input graph.
        return before()

    res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("cuda").max_function_args))
    exp = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(res, exp)


if __name__ == "__main__":
    # Run both tests directly when this file is invoked as a script.
    test_split_concat_metal()
    test_split_concat_cuda()

0 comments on commit fa4e6a9

Please sign in to comment.