Skip to content

Commit

Permalink
[CMSIS-NN] Convert scalar constants to tensor constants (#10100)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashutosh-arm authored Feb 2, 2022
1 parent 565e6b4 commit cb3d7e2
Show file tree
Hide file tree
Showing 6 changed files with 671 additions and 94 deletions.
19 changes: 16 additions & 3 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
transform.AnnotateTarget("cmsis-nn"),
transform.PartitionGraph(),
GenerateCMSISNNConstants(),
ScalarToTensorConstants(),
ExtractConstantsFromPartitionedFunction(),
transform.InferType(),
]
Expand Down Expand Up @@ -223,11 +224,23 @@ def binary_op_pattern(op):
is_constant(),
)

def check_qnn_binary_op(extract):
def check_qnn_binary_op(pattern):
"""Check if the binary op is supported by CMSIS-NN."""
arg0 = pattern.args[0]
arg1 = pattern.args[1]
both_args_scalar = False
if (
isinstance(arg0, tvm.relay.expr.Constant)
and len(arg0.checked_type.shape) == 0
and isinstance(arg1, tvm.relay.expr.Constant)
and len(arg1.checked_type.shape) == 0
):
both_args_scalar = True

return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
arg0.checked_type.dtype == "int8"
and arg1.checked_type.dtype == "int8"
and not both_args_scalar
)

return [
Expand Down
125 changes: 93 additions & 32 deletions src/relay/backend/contrib/cmsisnn/extract_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,88 @@ class ExtractConstantsMutator : public MixedModeMutator {
return func;
}

function_to_constants_.Set(func, Array<Constant>{});
function_to_arguments_.Set(func, Array<Expr>{});
functions_.push_back(func);
auto new_body = VisitExpr(func->body);
functions_.pop_back();
if (function_to_constants_[func].size()) {
if (function_to_arguments_[func].size()) {
func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
FreeTypeVars(new_body, mod_), func->attrs);
}
return std::move(func);
}

// Creates new arguments from current call's arguments
// Updates constants into the caller arguments: here caller signifies caller that comprises call
// to func
Array<Expr> CreateNewCallArgsFromExtractedConstants(Call call, Function func) {
ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
Array<Expr> function_signature(function_to_arguments_[func]);

// Is func a global_function?
// main() is not registered for extracting constants
bool is_global_function = functions_.empty() ? true : false;

bool new_constants_added = false;
// This tracks arguments traversed inside function_signature
uint32_t function_signature_id = 0;
// This contains arguments including constants for the caller of this function inside which
// post_call resides.
Array<Expr> new_caller_args;
// New arguments to post_call that includes new variables representing constants extracted from
// the function
Array<Expr> new_call_args;
for (auto& arg : call->args) {
if (auto* constant = arg.as<ConstantNode>()) {
new_caller_args.push_back(arg);
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
++function_signature_id;
new_constants_added = true;
continue;
}

// Push all constants from the function_signature until a variable corresponding to the
// current argument is hit
while (function_signature_id < function_signature.size()) {
auto* constant = function_signature[function_signature_id].as<ConstantNode>();
if (constant == nullptr) {
break;
}
new_caller_args.push_back(function_signature[function_signature_id++]);
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
new_constants_added = true;
}

new_call_args.push_back(arg);
if (is_global_function || arg.as<VarNode>()) {
new_caller_args.push_back(arg);
}
++function_signature_id;
}

// Push remaining constants as new arguments
for (uint32_t i = function_signature_id; i < function_signature.size(); ++i) {
auto* constant = function_signature[i].as<ConstantNode>();
ICHECK(constant)
<< "Rest of the collected arguments should be constant in the partitioned function.";
new_caller_args.push_back(GetRef<Constant>(constant));
new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
new_constants_added = true;
}

// Update the arguments of caller of local function
if (new_constants_added && !is_global_function) {
const Function& last_func = functions_.back();
Array<Expr> function_constants(function_to_arguments_[last_func]);
function_to_arguments_.Set(last_func,
tvm::runtime::Concat(function_constants, new_caller_args));
} else {
new_call_args = new_caller_args;
}

return new_call_args;
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expr final_call = post;
auto* post_call = post.as<CallNode>();
Expand All @@ -81,58 +152,47 @@ class ExtractConstantsMutator : public MixedModeMutator {
// Perform this for non-main Call Nodes only
if (!functions_.empty() && call->op.as<OpNode>()) {
Array<Expr> new_args;
const Function& last_func = functions_.back();
Array<Expr> function_signature(function_to_arguments_[last_func]);
for (auto& arg : post_call->args) {
// Push all arguments including constants to maintain correct order of
// variables and constants
auto* const_arg = arg.as<ConstantNode>();
if (const_arg && !const_arg->is_scalar()) {
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
new_args.push_back(var_arg);
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
fconstants.push_back(GetRef<Constant>(const_arg));
function_to_constants_.Set(last_func, fconstants);
function_signature.push_back(arg);
} else {
if (arg.as<VarNode>()) {
function_signature.push_back(arg);
}
new_args.push_back(arg);
}
}
function_to_arguments_.Set(last_func, function_signature);
final_call = Call(call->op, new_args, call->attrs, {});
}

// Since the constants are kicked out of partitioned functions
// Since the constants are extracted from partitioned functions
// a new call to global function is needed
if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
auto glob_var = GetRef<GlobalVar>(glob_var_node);
auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
auto new_glob_func = VisitExpr(glob_func);
if (!new_glob_func.same_as(glob_func)) {
mod_->Update(glob_var, Downcast<Function>(new_glob_func));
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
for (auto constant : function_to_constants_.at(glob_func)) {
new_args.push_back(constant);
}
auto new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), glob_func);
final_call = Call(glob_var, new_args);
}
}

// Since the constants are kicked out of the local partitioned functions
// Since the constants are extracted from the local partitioned functions
// a new call to local function is needed
// Also, pass on the constants to the callee of this function to support nested functions
if (auto* func_node = call->op.as<FunctionNode>()) {
Function func = GetRef<Function>(func_node);
auto new_func = VisitExpr(func);
if (!new_func.same_as(func)) {
Array<Expr> new_args = post_call->args;
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
const Function& last_func = functions_.back();
Array<Constant> fconstants(function_to_constants_[last_func]);
for (auto constant : function_to_constants_.at(func)) {
fconstants.push_back(constant);
Var var_arg = Var(gen_var_name(), constant->tensor_type());
new_args.push_back(var_arg);
}
function_to_constants_.Set(last_func, fconstants);
final_call = Call(new_func, new_args);
}
Array<Expr> new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), func);
final_call = Call(new_func, new_args);
}

return final_call;
Expand All @@ -141,15 +201,16 @@ class ExtractConstantsMutator : public MixedModeMutator {
private:
/* \brief Updated module where all calls have replaced constants with new variables */
IRModule mod_;
/* \brief Maintains mapping of original function to the replaced constants */
Map<Function, Array<Constant>> function_to_constants_;
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
/* \brief Maintains mapping of original function to the replaced constants along with other
* arguments to retain the order in which variables are used within the function */
Map<Function, Array<Expr>> function_to_arguments_;
/* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
Array<Function> functions_;
/* \brief Keeps track of variables being created */
int var_count_ = 0;
};

/*! * \brief Kicks out all constants out of the partitioned function into main() */
/*! * \brief Extracts all constants out of the partitioned function into main() */
IRModule ExtractConstants(const IRModule& mod) {
String func_name;
Function func;
Expand All @@ -169,7 +230,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
{});
{"InferType"});
}

TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
Expand Down
Loading

0 comments on commit cb3d7e2

Please sign in to comment.