[CMSIS-NN] Pad fusion with QNN Conv2D #12353

Merged 5 commits on Aug 23, 2022
50 changes: 45 additions & 5 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -59,6 +59,7 @@ def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts):
transform.AnnotateTarget("cmsis-nn"),
transform.PartitionGraph(mod_name=mod_name),
GenerateCMSISNNConstants(),
CMSISNNFusePads(),
ScalarToTensorConstants(),
ExtractConstantsFromPartitionedFunction(),
transform.InferType(),
@@ -91,10 +92,18 @@ def check_qnn_softmax(pattern):
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def qnn_conv2d_pattern():
"""Create pattern for qnn.conv2D with optional fused relu."""
def qnn_conv2d_pattern(with_pad):
"""Create pattern for qnn.conv2D with optional pad and/or optional fused relu."""
conv2d_input = wildcard()
if with_pad:
conv2d_input = is_op("nn.pad")(wildcard(), is_constant())
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
conv2d_input,
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
@@ -136,7 +145,7 @@ def check_qnn_conv2d(pattern):
):
is_depthwise = True

return (
ret = (
conv2d.attrs.out_dtype == "int32"
and conv2d_input.checked_type.dtype == "int8"
and conv2d_weight.checked_type.dtype == "int8"
@@ -145,6 +154,36 @@ def check_qnn_conv2d(pattern):
and all([zp == 0 for zp in kernel_zp])
and (not is_depthwise or bias_add is not None)
)
return ret

def check_qnn_conv2d_pad(pattern):
"""Check if the Pad followed by Conv2D is supported by CMSIS-NN."""
if str(pattern.op.name) == "clip":
relu = pattern
requantize = relu.args[0]
else:
requantize = pattern
requantize_input = requantize.args[0]
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
conv2d = bias_add.args[0]
else:
conv2d = requantize_input
conv2d_input = conv2d.args[0]

# check if sum of paddings from pad() and conv2d() satisfies CMSIS-NN constraints
can_pad_be_fused = True
if isinstance(conv2d_input, tvm.relay.expr.Call) and str(conv2d_input.op.name) == "nn.pad":
pad_top, pad_left, pad_bottom, pad_right = GetEffectiveConv2DPadding(
conv2d, conv2d_input
)
# check if the difference in side paddings is either 0 or 1 along each dimension
pad_w_diff = int(pad_right - pad_left)
pad_h_diff = int(pad_bottom - pad_top)
can_pad_be_fused = pad_w_diff in [0, 1] and pad_h_diff in [0, 1]

ret = check_qnn_conv2d(pattern) and can_pad_be_fused
return ret

def qnn_fully_connected_pattern():
"""Create pattern for qnn.dense with optional Relu."""
@@ -275,7 +314,8 @@ def check_qnn_binary_op(pattern):
)

return [
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=True), check_qnn_conv2d_pad),
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(with_pad=False), check_qnn_conv2d),
("cmsis-nn.qnn_fully_connected", qnn_fully_connected_pattern(), check_qnn_fully_connected),
("cmsis-nn.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_qnn_avg_pool2d),
("cmsis-nn.qnn_max_pool2d", qnn_max_pool2d_pattern(), check_qnn_max_pool2d),
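
For context, a minimal usage sketch (not part of this diff): partition_for_cmsisnn now runs CMSISNNFusePads() in its pass sequence, so an nn.pad feeding a qnn.conv2d inside a CMSIS-NN partition is folded into the convolution's padding during partitioning. Here mod and params are assumed to be a quantized Relay module and its parameters, e.g. imported via relay.frontend.from_tflite.

    from tvm.relay.op.contrib import cmsisnn

    # Partition the quantized module for CMSIS-NN; the new CMSISNNFusePads pass
    # runs inside this call and fuses eligible nn.pad ops into qnn.conv2d padding.
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(mod, params)
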
209 changes: 209 additions & 0 deletions src/relay/backend/contrib/cmsisnn/fuse_pads.cc
@@ -0,0 +1,209 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/backend/contrib/cmsisnn/fuse_pads.cc
* \brief Fuses pads that precede qnn.conv2d ops inside CMSIS-NN composite functions.
*/

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>

#include "../../../op/make_op.h"
#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"
#include "convolutions.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

inline IntImm ToIntImm(int32_t value) { return IntImm(DataType::Int(32), value); }

/*!
* \brief From padding attributes of nn.pad and qnn.conv2d, calculates effective padding along H
* and W dimensions.
*/
Array<IntImm> GetEffectiveConv2DPadding(Expr conv2d, Expr pad) {
// pad_width: ((), (top, bottom), (left, right), ()) for NHWC layout
// conv2d_attrs->padding: (top, left, bottom, right)
auto* conv2d_call = conv2d.as<CallNode>();
auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
std::string data_layout = conv2d_attrs->data_layout.c_str();
int pos_h = data_layout.find("H");
int pos_w = data_layout.find("W");

auto* pad_call = pad.as<CallNode>();
Array<Array<Integer>> pad_width = pad_call->attrs.as<PadAttrs>()->pad_width;
int pad_top =
qnn::get_const_int(conv2d_attrs->padding[0]) + qnn::get_const_int(pad_width[pos_h][0]);
int pad_left =
qnn::get_const_int(conv2d_attrs->padding[1]) + qnn::get_const_int(pad_width[pos_w][0]);
int pad_bottom =
qnn::get_const_int(conv2d_attrs->padding[2]) + qnn::get_const_int(pad_width[pos_h][1]);
int pad_right =
qnn::get_const_int(conv2d_attrs->padding[3]) + qnn::get_const_int(pad_width[pos_w][1]);

return {ToIntImm(pad_top), ToIntImm(pad_left), ToIntImm(pad_bottom), ToIntImm(pad_right)};
}

/*!
* \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
* Then, it will fuse preceding pads with qnn.conv2d.
*/
class FusePadsMutator : public MixedModeMutator {
public:
explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {}

private:
/*!
* \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto
* convolution layer to update Conv2DAttrs's padding attribute. */
void UpdateConv2DPadding(const CallNode* conv2d_call, const CallNode* pad_call,
Attrs* new_attrs) {
Array<IntImm> effective_padding =
GetEffectiveConv2DPadding(GetRef<Call>(conv2d_call), GetRef<Call>(pad_call));
int pad_top = effective_padding[0]->value;
int pad_left = effective_padding[1]->value;
int pad_bottom = effective_padding[2]->value;
int pad_right = effective_padding[3]->value;
int pad_diff_w = pad_right - pad_left;
int pad_diff_h = pad_bottom - pad_top;
bool can_pad_be_fused =
((pad_diff_w == 0 || pad_diff_w == 1) && (pad_diff_h == 0 || pad_diff_h == 1));
std::string error = "Difference on each side of a dimension should be either 0 or 1. ";
error += "Effective padding in this case: (pad_top, pad_left, pad_bottom, pad_right)=(";
error += std::to_string(pad_top);
error += ", ";
error += std::to_string(pad_left);
error += ", ";
error += std::to_string(pad_bottom);
error += ", ";
error += std::to_string(pad_right);
error += ")";
ICHECK(can_pad_be_fused) << error;

// Prepare new attrs as padding has changed
auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
auto attrs = make_object<Conv2DAttrs>();
attrs->strides = std::move(conv2d_attrs->strides);
attrs->dilation = std::move(conv2d_attrs->dilation);
attrs->groups = conv2d_attrs->groups;
attrs->channels = std::move(conv2d_attrs->channels);
attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
attrs->data_layout = std::move(conv2d_attrs->data_layout);
attrs->kernel_layout = std::move(conv2d_attrs->kernel_layout);
attrs->out_layout = std::move(conv2d_attrs->out_layout);
attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
attrs->padding = {pad_top, pad_left, pad_bottom, pad_right};
*new_attrs = tvm::Attrs{attrs};
}

/*!
* \brief Identifies the sequence for qnn.conv2D and fuses the preceding nn.pad present within the
* CMSIS-NN partitioned function. */
Expr FusePadConv2d(const CallNode* conv2d_call) {
// create new paddings for qnn.conv2d
tvm::Attrs new_conv2d_attrs = conv2d_call->attrs;
Expr new_conv2d_input = conv2d_call->args[0];
if (auto* pad_call = conv2d_call->args[0].as<CallNode>()) {
if (auto* pad_call_op = pad_call->op.as<OpNode>()) {
if (pad_call_op->name == "nn.pad") {
new_conv2d_input = pad_call->args[0];
UpdateConv2DPadding(conv2d_call, pad_call, &new_conv2d_attrs);
}
}
}

// Conv2D arguments: pad's input + rest of the original args
auto new_conv2d_args = conv2d_call->args;
new_conv2d_args.erase(new_conv2d_args.begin());
new_conv2d_args.insert(new_conv2d_args.begin(), new_conv2d_input);
Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {});
return std::move(ret_call);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr ret_call = post;
auto* post_call = post.as<CallNode>();

// Fuse nn.pad and qnn.conv2d
if (auto* conv2d_op = post_call->op.as<OpNode>()) {
if (conv2d_op->name == "qnn.conv2d") {
ret_call = FusePadConv2d(post_call);
}
}

// Identify qnn.conv2d partitioned function
if (post_call->op.as<FunctionNode>()) {
auto* func = call->op.as<FunctionNode>();
auto func_name = func->GetAttr<String>(attr::kComposite);
if (func_name.defined() && func_name == "cmsis-nn.qnn_conv2d") {
Expr new_body = VisitExpr(func->body);
Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
ret_call = Call(new_func, post_call->args);
}
}

return ret_call;
}

private:
IRModule mod_;
};

IRModule FusePads(const IRModule& mod) {
for (auto gv : mod->GetGlobalVars()) {
Function func = Downcast<Function>(mod->Lookup(gv));

// only mutate CMSIS-NN partitioned functions
auto compiler_name = func->GetAttr<String>(attr::kCompiler);
if (!compiler_name.defined() || compiler_name != "cmsis-nn") {
continue;
}

auto fuse_pads_mutator = FusePadsMutator(mod);
auto new_func_body = fuse_pads_mutator.VisitExpr(func->body);
if (!new_func_body.same_as(func->body)) {
Function new_func =
Function(func->params, new_func_body, func->ret_type, func->type_params, func->attrs);
mod->Update(gv, new_func);
}
}
return mod;
}

transform::Pass CMSISNNFusePads() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return FusePads(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "CMSISNNFusePads", {});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.CMSISNNFusePads").set_body_typed(CMSISNNFusePads);
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GetEffectiveConv2DPadding")
.set_body_typed(GetEffectiveConv2DPadding);

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm
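
To make the padding arithmetic concrete, here is a minimal Python sketch (not part of the diff) of what GetEffectiveConv2DPadding and the fusibility check in UpdateConv2DPadding compute, assuming an NHWC data layout so pad_width is ((0, 0), (top, bottom), (left, right), (0, 0)); effective_conv2d_padding is a hypothetical helper used only for illustration.

    def effective_conv2d_padding(conv2d_padding, pad_width):
        # Sum qnn.conv2d padding (top, left, bottom, right) with nn.pad's pad_width.
        top, left, bottom, right = conv2d_padding
        (pad_top, pad_bottom), (pad_left, pad_right) = pad_width[1], pad_width[2]
        return (top + pad_top, left + pad_left, bottom + pad_bottom, right + pad_right)

    # Fusion is legal only when the per-dimension asymmetry is 0 or 1,
    # mirroring the ICHECK in UpdateConv2DPadding.
    top, left, bottom, right = effective_conv2d_padding(
        (0, 0, 0, 0), ((0, 0), (0, 1), (0, 1), (0, 0))
    )
    assert (bottom - top) in (0, 1) and (right - left) in (0, 1)  # (0, 0, 1, 1) can be fused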