[CMSIS-NN] Convert scalar constants to tensor constants #10100

Merged · 4 commits · Feb 2, 2022
19 changes: 16 additions & 3 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -57,6 +57,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
transform.AnnotateTarget("cmsis-nn"),
transform.PartitionGraph(),
GenerateCMSISNNConstants(),
ScalarToTensorConstants(),
ExtractConstantsFromPartitionedFunction(),
transform.InferType(),
]
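
For context, a hedged usage sketch of the pipeline above: partition_for_cmsisnn applies these passes in order, so the new ScalarToTensorConstants() runs after the CMSIS-NN constants are generated and before they are extracted into main(). The module and parameter names below are placeholders, not taken from this PR.

    from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn

    # `relay_mod` and `params` are assumed to be a quantized (int8) Relay module
    # and its parameter dict; partitioning tags supported operators for the
    # cmsis-nn target and then runs the constant-handling passes listed above.
    cmsisnn_mod = partition_for_cmsisnn(relay_mod, params)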
@@ -223,11 +224,23 @@ def binary_op_pattern(op):
is_constant(),
)

def check_qnn_binary_op(extract):
def check_qnn_binary_op(pattern):
"""Check if multiply is supported by CMSIS-NN."""
arg0 = pattern.args[0]
arg1 = pattern.args[1]
both_args_scalar = False
if (
isinstance(arg0, tvm.relay.expr.Constant)
and len(arg0.checked_type.shape) == 0
and isinstance(arg1, tvm.relay.expr.Constant)
and len(arg1.checked_type.shape) == 0
):
both_args_scalar = True

return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
arg0.checked_type.dtype == "int8"
and arg1.checked_type.dtype == "int8"
and not both_args_scalar
)

return [
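The new both_args_scalar guard keeps a qnn binary op from being offloaded when both operands are rank-0 (scalar) constants. Below is a minimal sketch of the underlying shape test, assuming TVM's Relay Python API; the helper name is invented for illustration and reads the shape from the constant's backing NDArray rather than from checked_type:

    import numpy as np
    from tvm import relay

    scalar = relay.const(np.int8(5))                     # rank-0 constant, shape ()
    tensor = relay.const(np.ones((1, 4), dtype="int8"))  # rank-2 constant

    def is_scalar_constant(expr):
        # Mirrors the len(arg.checked_type.shape) == 0 test in the predicate above.
        return isinstance(expr, relay.Constant) and len(expr.data.shape) == 0

    print(is_scalar_constant(scalar))  # True  -> a call with two such args is rejected
    print(is_scalar_constant(tensor))  # False -> eligible for CMSIS-NN offload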
125 changes: 93 additions & 32 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -62,17 +62,88 @@ class ExtractConstantsMutator : public MixedModeMutator {
return func;
}

function_to_constants_.Set(func, Array<Constant>{});
function_to_arguments_.Set(func, Array<Expr>{});
functions_.push_back(func);
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
if (function_to_arguments_[func].size()) {
func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
}
return std::move(func);
}

// Creates new arguments for the current call from its existing arguments plus the constants
// extracted out of func. Here, "caller" means the function whose body contains the call
// to func.
Array<Expr> CreateNewCallArgsFromExtractedConstants(Call call, Function func) {
ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
Array<Expr> function_signature(function_to_arguments_[func]);

// Is func a global_function?
// main() is not registered for extracting constants
bool is_global_function = functions_.empty() ? true : false;

bool new_constants_added = false;
// This tracks arguments traversed inside function_signature
uint32_t function_signature_id = 0;
// This contains arguments including constants for the caller of this function inside which
// post_call resides.
Array<Expr> new_caller_args;
// New arguments to post_call, which include new variables representing constants
// extracted from the function
Array<Expr> new_call_args;
for (auto& arg : call->args) {
if (auto* constant = arg.as<ConstantNode>()) {
new_caller_args.push_back(arg);
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
++function_signature_id;
new_constants_added = true;
continue;
}

// Push all constants from the function_signature until a variable corresponding to the
// current argument is hit
while (function_signature_id < function_signature.size()) {
auto* constant = function_signature[function_signature_id].as<ConstantNode>();
if (constant == nullptr) {
break;
}
new_caller_args.push_back(function_signature[function_signature_id++]);
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
new_constants_added = true;
}

new_call_args.push_back(arg);
if (is_global_function || arg.as<VarNode>()) {
new_caller_args.push_back(arg);
}
++function_signature_id;
}

// Push remaining constants as new arguments
for (uint32_t i = function_signature_id; i < function_signature.size(); ++i) {
auto* constant = function_signature[i].as<ConstantNode>();
ICHECK(constant)
<< "Rest of the collected arguments should be constant in the partitioned function.";
new_caller_args.push_back(GetRef<Constant>(constant));
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
new_constants_added = true;
}

// Update the arguments of caller of local function
if (new_constants_added && !is_global_function) {
const Function& last_func = functions_.back();
Array<Expr> function_constants(function_to_arguments_[last_func]);
function_to_arguments_.Set(last_func,
tvm::runtime::Concat(function_constants, new_caller_args));
} else {
new_call_args = new_caller_args;
}

return new_call_args;
}
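
To make the interleaving above concrete, here is a small standalone sketch (plain Python, not TVM code; the names are invented) of how the recorded function signature, which holds variables and extracted constants in their original order, is merged with the caller's current arguments so that constants land in the right positions:

    # "var"/"const" strings stand in for Relay Vars and Constants. This mirrors
    # only the new_caller_args bookkeeping of the function above, not the full pass.
    def merge_caller_args(call_args, signature):
        merged, sig_idx = [], 0
        for arg in call_args:
            if arg.startswith("const"):    # constant passed directly at the call site
                merged.append(arg)
                sig_idx += 1
                continue
            # flush constants recorded in the signature before this variable
            while sig_idx < len(signature) and signature[sig_idx].startswith("const"):
                merged.append(signature[sig_idx])
                sig_idx += 1
            merged.append(arg)             # the variable itself
            sig_idx += 1
        merged.extend(signature[sig_idx:]) # trailing constants extracted from the body
        return merged

    print(merge_caller_args(["var_x"], ["var_x", "const_w", "const_b"]))
    # ['var_x', 'const_w', 'const_b']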

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();
@@ -81,58 +152,47 @@ class ExtractConstantsMutator : public MixedModeMutator {
// Perform this for non-main Call Nodes only
if (!functions_.empty() && call->op.as<OpNode>()) {
Array<Expr> new_args;
const Function& last_func = functions_.back();
Array<Expr> function_signature(function_to_arguments_[last_func]);
for (auto& arg : post_call->args) {
// Push all arguments including constants to maintain correct order of
// variables and constants
auto* const_arg = arg.as<ConstantNode>();
if (const_arg && !const_arg->is_scalar()) {
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
new_args.push_back(var_arg);
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
fconstants.push_back(GetRef<Constant>(const_arg));
function_to_constants_.Set(last_func, fconstants);
function_signature.push_back(arg);
} else {
if (arg.as<VarNode>()) {
function_signature.push_back(arg);
}
new_args.push_back(arg);
}
}
function_to_arguments_.Set(last_func, function_signature);
final_call = Call(call->op, new_args, call->attrs, {});
}

// Since the constants are kicked out of partitioned functions
// Since the constants are extracted from partitioned functions
// a new call to global function is needed
if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
auto glob_var = GetRef<GlobalVar>(glob_var_node);
auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
auto new_glob_func = VisitExpr(glob_func);
if (!new_glob_func.same_as(glob_func)) {
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
for (auto constant : function_to_constants_.at(glob_func)) {
new_args.push_back(constant);
}
auto new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), glob_func);
final_call = Call(glob_var, new_args);
}
}

// Since the constants are kicked out of the local partitioned functions
// Since the constants are extracted from the local partitioned functions
// a new call to local function is needed
// Also, pass on the constants to the callee of this function to support nested functions
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
auto new_func = VisitExpr(func);
if (!new_func.same_as(func)) {
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
for (auto constant : function_to_constants_.at(func)) {
fconstants.push_back(constant);
Var var_arg = Var(gen_var_name(), constant->tensor_type());
new_args.push_back(var_arg);
}
function_to_constants_.Set(last_func, fconstants);
final_call = Call(new_func, new_args);
}
Array<Expr> new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), func);
final_call = Call(new_func, new_args);
}

return final_call;
@@ -141,15 +201,16 @@ class ExtractConstantsMutator : public MixedModeMutator {
private:
/* \brief Updated module where all calls have replaced constants with new variables */
IRModule mod_;
/* \brief Maintains mapping of original function to the replaced constants */
Map<Function, Array<Constant>> function_to_constants_;
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
/* \brief Maintains mapping of original function to the replaced constants along with other
* arguments to retain the order in which variables are used within the function */
Map<Function, Array<Expr>> function_to_arguments_;
/* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
Array<Function> functions_;
/* \brief Keeps track of variables being created */
int var_count_ = 0;
};

/*! * \brief Kicks out all constants out of the partitioned function into main() */
/*! * \brief Extracts all constants out of the partitioned function into main() */
IRModule ExtractConstants(const IRModule& mod) {
String func_name;
Function func;
@@ -169,7 +230,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
{});
{"InferType"});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
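
As a usage note (a sketch, not part of this PR): the registered global above can be looked up from Python to build and apply the pass, and the {"InferType"} requirement added in this change means type inference is run as a prerequisite before constants are extracted.

    import tvm

    # `mod` is assumed to be a partitioned Relay IRModule. The global name matches
    # the registration above; calling the retrieved function returns the module
    # pass, which is then applied to the module.
    make_pass = tvm.get_global_func(
        "relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction"
    )
    mod = make_pass()(mod)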