Skip to content

Commit

Permalink
[4/10] Code generation for Conv2D via CMSIS-NN (apache#9331)
Browse files Browse the repository at this point in the history
This PR is for support of Conv2D via CMSIS-NN.
  • Loading branch information
ashutosh-arm authored and yangulei committed Jan 11, 2022
1 parent 4b1667f commit 8c5cdb5
Show file tree
Hide file tree
Showing 21 changed files with 1,719 additions and 89 deletions.
12 changes: 9 additions & 3 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,14 @@ def _build_function_memory_map(function_metadata):
device_max_workspace = dict()
main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR]
num_targets = len(main_func_metadata.workspace_sizes.items())
from tvm.driver import tvmc # pylint: disable=import-outside-toplevel

external_codegens = tvmc.composite_target.get_codegen_names()
func_entries = []
target_local_entries = dict()
for i in range(num_targets):
target = main_func_metadata.workspace_sizes.items()[i][0]
device_max_workspace[target] = 0
main_target = main_func_metadata.workspace_sizes.items()[i][0]
device_max_workspace[main_target] = 0
for func_name, finfo in function_metadata.items():
if func_name == MAIN_FUNC_NAME_STR:
continue
Expand All @@ -201,8 +204,11 @@ def _build_function_memory_map(function_metadata):
"workspace_size_bytes": int(workspace_size),
}
target_local_entries[func_name].append(target_entry)
if workspace_size > device_max_workspace[target]:
if workspace_size > device_max_workspace.get(target, 0):
device_max_workspace[target] = workspace_size
# TODO(Mousius) - Remove this massive hack when Targets are unified
if target.kind.name in external_codegens:
device_max_workspace[main_target] += int(workspace_size)

for func_name, target_entries_ in target_local_entries.items():
func_entry = {
Expand Down
83 changes: 63 additions & 20 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from ...dataflow_pattern import is_constant, is_op, wildcard
from .register import register_pattern_table

tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)


def enabled():
return "cmsis-nn" in Target.list_kinds()
Expand Down Expand Up @@ -53,37 +55,85 @@ def partition_for_cmsisnn(mod, params=None, **opts):
transform.InferType(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("cmsis-nn"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
GenerateCMSISNNConstants(),
ExtractConstantsFromPartitionedFunction(),
transform.InferType(),
]
)

return seq(mod)


@register_pattern_table("cmsis-nn")
def pattern_table():
"""Get the CMSIS-NN compiler pattern table."""

def softmax_pattern():
def qnn_softmax_pattern():
"""Create pattern for quantized softmax"""
pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
pattern = is_op("nn.softmax")(pattern)
pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
return pattern

def check_quantized_softmax(extract):
def check_qnn_softmax(pattern):
"""Check if softmax is supported by CMSIS-NN."""
dequantize_call = extract.args[0].args[0]
scale = extract.args[1].data.numpy().item(0)
zero_point = extract.args[2].data.numpy().item(0)
dequantize_call = pattern.args[0].args[0]
scale = pattern.args[1].data.numpy().item(0)
zero_point = pattern.args[2].data.numpy().item(0)

# check for dtypes of quantize and dequantize
return (
(scale == 1.0 / 256 and zero_point == -128)
and extract.attrs.out_dtype == "int8"
and pattern.attrs.out_dtype == "int8"
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def qnn_conv2d_pattern():
"""Create pattern for qnn.conv2D with optional fused relu."""
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
)
clip_or_req = req.optional(is_op("clip"))
return clip_or_req

def check_qnn_conv2d(pattern):
"""Check if the Conv2D is supported by CMSIS-NN."""
if str(pattern.op.name) == "clip":
relu = pattern
requantize = relu.args[0]
else:
requantize = pattern
requantize_input = requantize.args[0]
bias_add = None
bias_dtype = "int32"
if str(requantize_input.op.name) == "nn.bias_add":
bias_add = requantize_input
conv2d = bias_add.args[0]
bias_dtype = bias_add.args[1].checked_type.dtype
else:
conv2d = requantize_input
conv2d_input = conv2d.args[0]
conv2d_weight = conv2d.args[1]

# kernel zero_point should be 0
kernel_zp = conv2d.args[3].data.numpy()
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp

return (
conv2d.attrs.out_dtype == "int32"
and conv2d.attrs.padding[2] == 0
and conv2d.attrs.padding[3] == 0
and conv2d_input.checked_type.dtype == "int8"
and conv2d_weight.checked_type.dtype == "int8"
and pattern.checked_type.dtype == "int8"
and bias_dtype == "int32"
and all([zp == 0 for zp in kernel_zp])
)

def binary_op_pattern(op):
"""Matches QNN binary operation"""
return is_op(f"qnn.{op}")(
Expand All @@ -97,23 +147,16 @@ def binary_op_pattern(op):
is_constant(),
)

def check_quantized_binary_op(extract):
def check_qnn_binary_op(extract):
"""Check if multiply is supported by CMSIS-NN."""
return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
)

return [
("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
(
"cmsis-nn.quantized_mul",
binary_op_pattern("mul"),
check_quantized_binary_op,
),
(
"cmsis-nn.quantized_add",
binary_op_pattern("add"),
check_quantized_binary_op,
),
("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
]
28 changes: 8 additions & 20 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,25 +698,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}

LoweredOutput Codegen(relay::Function func, String mod_name) {
AOTOnDemandAllocator initial_aot_allocator;
initial_aot_allocator.Run(func);

// Pre-lowering storage map and memory plan
// TODO(mbs): Why plan memory and update workspace sizes before lowering?
StorageMap initial_storage_map = initial_aot_allocator.GetStorageMap();
StaticMemoryPlan memory_plan(initial_storage_map);

IRModule mod = IRModule::FromExpr(func);

backend::FunctionInfo func_info;

if (memory_plan.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info);
mod = WithAttr(mod, "main_func_info", func_info);
}

IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](Function func) {
IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand All @@ -733,12 +716,17 @@ class AOTExecutorCodegen : public MixedModeVisitor {
auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
// created, just referencing the new expressions created from lowering
// Post-lowering storage map for writing main func
AOTOnDemandAllocator final_aot_allocator;
final_aot_allocator.Run(lowered_main_func);
storage_device_map_ = final_aot_allocator.GetStorageMap();

// TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize
StaticMemoryPlan memory_plan(storage_device_map_);
backend::FunctionInfo func_info =
tec::UpdateMainWorkspaceSize(lowered_mod, targets_, memory_plan->expr_to_storage_info);
lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);

for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
main_signature_.push_back(tir::Var("input", DataType::Handle()));
Expand Down
168 changes: 168 additions & 0 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file extract_constant.cc
* \brief Pushes out constants within partitioned functions all the way upto main()
*/

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/ndarray.h>

#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

/*!
* \brief This Mutator finds all functions with constants. Constants are replaced with function
* parameter variables. Constants are pushed all the way upto main().
*/
class ExtractConstantsMutator : public MixedModeMutator {
public:
explicit ExtractConstantsMutator(const IRModule& mod) : mod_(mod) {}

private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

Expr VisitExpr_(const FunctionNode* function) final {
Function func = GetRef<Function>(function);
function_to_constants_.Set(func, Array<Constant>{});
functions_.push_back(func);
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
}
return func;
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();

// Replace Constant arguments with Vars for ML Operators
// Perform this for non-main Call Nodes only
if (!functions_.empty() && call->op.as<OpNode>()) {
Array<Expr> new_args;
for (auto& arg : post_call->args) {
auto* const_arg = arg.as<ConstantNode>();
if (const_arg && !const_arg->is_scalar()) {
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
new_args.push_back(var_arg);
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
fconstants.push_back(GetRef<Constant>(const_arg));
function_to_constants_.Set(last_func, fconstants);
} else {
new_args.push_back(arg);
}
}
final_call = Call(call->op, new_args, call->attrs, {});
}

// Since the constants are kicked out of partitioned functions
// a new call to global function is needed
if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
auto glob_var = GetRef<GlobalVar>(glob_var_node);
auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
auto new_glob_func = VisitExpr(glob_func);
if (!new_glob_func.same_as(glob_func)) {
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
for (auto constant : function_to_constants_.at(glob_func)) {
new_args.push_back(constant);
}
final_call = Call(glob_var, new_args);
}
}

// Since the constants are kicked out of the local partitioned functions
// a new call to local function is needed
// Also, pass on the constants to the callee of this function to support nested functions
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
auto new_func = VisitExpr(func);
if (!new_func.same_as(func)) {
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
for (auto constant : function_to_constants_.at(func)) {
fconstants.push_back(constant);
Var var_arg = Var(gen_var_name(), constant->tensor_type());
new_args.push_back(var_arg);
}
function_to_constants_.Set(last_func, fconstants);
final_call = Call(new_func, new_args);
}
}

return final_call;
}

private:
/* \brief Updated module where all calls have replaced constants with new variables */
IRModule mod_;
/* \brief Maintains mapping of original function to the replaced constants */
Map<Function, Array<Constant>> function_to_constants_;
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
Array<Function> functions_;
/* \brief Keeps track of variables being created */
int var_count_ = 0;
};

/*! * \brief Kicks out all constants out of the partitioned function into main() */
IRModule ExtractConstants(const IRModule& mod) {
String func_name;
Function func;

auto extract_constants = ExtractConstantsMutator(mod);
Function main_func = Downcast<Function>(mod->Lookup("main"));
auto new_main_body = extract_constants.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
mod->Update(main_var, new_main_func);
}
return mod;
}

transform::Pass ExtractConstantsFromPartitionedFunction() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
{});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
.set_body_typed(ExtractConstantsFromPartitionedFunction);

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm
Loading

0 comments on commit 8c5cdb5

Please sign in to comment.