[Metal] Add pass for splitting kernel with huge number of args (apache#8313)

* [Metal] Add pass for splitting kernel with huge number of args

  Metal limits the number of input arguments that can be passed to a kernel. More information can be found here: https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc

  This commit adds a new pass that splits functions with a large number of arguments into smaller parts. The parameter `max_function_args` specifies the maximum number of kernel arguments for a given target; a kernel is split when its argument count exceeds that value. Currently this pass works only for the concatenate layer.

* Add getting the number of output parameters

* Fix CI and apply comments
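As a usage sketch (not part of the commit itself), the new pass can be applied to a Relay module much like the tests added below do; the shapes and the 100-input concatenate are illustrative:

import tvm
from tvm import relay
from tvm.relay import transform

# Build a concatenate call with many inputs (illustrative).
inputs = [relay.var("p{}".format(i), shape=(1, 1, 1, 3), dtype="float32") for i in range(100)]
body = relay.op.concatenate(relay.Tuple(inputs), axis=1)

mod = tvm.IRModule.from_expr(body)
mod = relay.transform.InferType()(mod)
# Split the call according to the Metal target's argument limit.
mod = transform.SplitArgs(tvm.target.Target("metal").max_function_args)(mod)
print(mod["main"])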
Showing 9 changed files with 233 additions and 3 deletions.
split_args.cc (new file, 95 lines added):
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file split_args.cc
 */
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

class ArgumentSplitter : public ExprRewriter {
 public:
  explicit ArgumentSplitter(int max_function_args)
      : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {}

  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    if (max_function_args_ < 0) return post;
    if (call->op == concat_op_) {
      auto op = call->args[0].as<TupleNode>();
      const auto param = call->attrs.as<ConcatenateAttrs>();
      // The kernel also needs buffers for its outputs; count them from the checked type.
      int outputsNum = 1;
      if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
        outputsNum = tuple_type->fields.size();
      }
      // Argument slots left for inputs after reserving buffers for the outputs.
      const int limit = max_function_args_ - outputsNum;
      int argsNum = op->fields.size();
      if (argsNum < limit) return post;
      // Number of partial concatenates, rounded up.
      int splitNum = argsNum / limit;
      splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;

      // Concatenate each chunk separately and stop fusion so the chunks are not merged back.
      std::vector<Expr> splitted(splitNum);
      for (int i = 0; i < splitNum; ++i) {
        int startIdx = i * limit;
        int argsCount = std::min(limit, argsNum - startIdx);
        tvm::Array<Expr> args;
        for (int j = 0; j < argsCount; ++j) {
          args.push_back(op->fields[j + startIdx]);
        }
        Tuple tuple(args);
        Expr body = MakeConcatenate(tuple, param->axis);
        splitted[i] = StopFusion(body);
      }
      // Join the partial results with a final concatenate.
      tvm::Array<Expr> tupleArgs(splitted);
      Tuple tuple(tupleArgs);
      return MakeConcatenate(tuple, param->axis);
    }
    return post;
  }

 private:
  const int max_function_args_;
  const Op& concat_op_;
};

Expr SplitArgs(const Expr& expr, int max_function_args) {
  auto rewriter = ArgumentSplitter(max_function_args);
  return PostOrderRewrite(expr, &rewriter);
}

namespace transform {

Pass SplitArgs(int max_function_args) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(SplitArgs(f, max_function_args));
      };
  return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs);

}  // namespace transform

}  // namespace relay
}  // namespace tvm
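A minimal Python sketch of the chunking arithmetic used in Rewrite_ above (the helper and the limit of 31 are illustrative, not TVM API): the available slots are max_function_args minus the number of output buffers, and the inputs are concatenated in ceil(argsNum / limit) groups before a final concatenate joins the partial results.

# Illustrative sketch of the chunking performed by ArgumentSplitter (hypothetical helper, not TVM API).
def split_indices(args_num, max_function_args, outputs_num=1):
    # Slots left for inputs after reserving buffers for the outputs.
    limit = max_function_args - outputs_num
    if args_num < limit:
        return [list(range(args_num))]  # nothing to split
    # Ceiling division: number of partial concatenates.
    split_num = args_num // limit + (1 if args_num % limit else 0)
    return [list(range(i * limit, min((i + 1) * limit, args_num))) for i in range(split_num)]

# 100 concat inputs with a hypothetical limit of 31 arguments and 1 output:
# -> chunks of 30, 30, 30, 10, each concatenated separately, then concatenated again.
print([len(c) for c in split_indices(100, 31)])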
New Python test file (96 lines added):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing import run_infer_type, create_workload


def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)

    mod = tvm.IRModule.from_expr(expr)
    mod = relay.transform.InferType()(mod)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def test_split_concat_metal():
    shape = (1, 1, 1, 3)
    dtype = "float32"
    axis = 1
    inputs = []
    for i in range(100):
        inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype))

    def before():
        inp = relay.Tuple(inputs)
        return relay.op.concatenate(inp, axis)

    def expected():
        limit = tvm.target.Target("metal").max_function_args - 1  # one buffer with output
        splitNum = int(len(inputs) / limit)
        if len(inputs) % limit > 0:
            splitNum += 1

        splitted = []
        for i in range(splitNum):
            startIdx = i * limit
            argsCount = min(limit, len(inputs) - startIdx)
            args = []
            for j in range(argsCount):
                args.append(inputs[j + startIdx])
            t = relay.Tuple(args)
            concat = relay.op.concatenate(t, axis)
            splitted.append(relay.annotation.stop_fusion(concat))
        inp = relay.Tuple(splitted)
        return relay.op.concatenate(inp, axis)

    # the pass should split the concatenate according to the Metal argument limit
    res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("metal").max_function_args))
    exp = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(res, exp)


def test_split_concat_cuda():
    shape = (1, 1, 1, 3)
    dtype = "float32"
    axis = 1
    inputs = []
    for i in range(100):
        inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype))

    def before():
        inp = relay.Tuple(inputs)
        return relay.op.concatenate(inp, axis)

    def expected():
        inp = relay.Tuple(inputs)
        return relay.op.concatenate(inp, axis)

    # the CUDA argument limit is not exceeded, so the graph should be unchanged
    res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("cuda").max_function_args))
    exp = run_opt_pass(expected(), transform.InferType())
    assert tvm.ir.structural_equal(res, exp)


if __name__ == "__main__":
    test_split_concat_metal()
    test_split_concat_cuda()
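Note on the CUDA test: its expected graph is identical to the input, which implicitly checks that the CUDA target either reports no argument limit or a limit well above the 100 inputs used here, so SplitArgs leaves the concatenate untouched. Both tests can also be run directly via the __main__ block at the end of the file.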