From a00d3d49b85becb64e928ad9b93c8e9dd6d78850 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Wed, 17 Nov 2021 09:51:24 +0000 Subject: [PATCH] [4/10] Code generation for Conv2D via CMSIS-NN (#9331) This PR is for support of Conv2D via CMSIS-NN. --- python/tvm/micro/model_library_format.py | 12 +- python/tvm/relay/op/contrib/cmsisnn.py | 83 +++-- src/relay/backend/aot_executor_codegen.cc | 28 +- .../contrib/cmsisnn/extract_constants.cc | 168 ++++++++++ .../contrib/cmsisnn/generate_constants.cc | 244 ++++++++++++++ .../backend/contrib/cmsisnn/relay_to_tir.cc | 174 +++++++++- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 164 +++++++++ src/relay/backend/graph_executor_codegen.cc | 2 +- src/relay/backend/interpreter.cc | 2 +- src/relay/backend/te_compiler.cc | 44 ++- src/relay/backend/te_compiler.h | 10 +- src/relay/backend/utils.h | 2 +- src/relay/backend/vm/compiler.cc | 3 +- src/target/source/codegen_c_host.h | 2 +- .../contrib/test_cmsisnn/test_binary_ops.py | 1 + .../contrib/test_cmsisnn/test_conv2d.py | 316 ++++++++++++++++++ .../test_cmsisnn/test_extract_constants.py | 179 ++++++++++ .../test_cmsisnn/test_generate_constants.py | 229 +++++++++++++ .../contrib/test_cmsisnn/test_networks.py | 12 +- .../contrib/test_cmsisnn/test_softmax.py | 6 +- tests/python/contrib/test_cmsisnn/utils.py | 127 ++++++- 21 files changed, 1719 insertions(+), 89 deletions(-) create mode 100644 src/relay/backend/contrib/cmsisnn/extract_constants.cc create mode 100644 src/relay/backend/contrib/cmsisnn/generate_constants.cc create mode 100644 tests/python/contrib/test_cmsisnn/test_conv2d.py create mode 100644 tests/python/contrib/test_cmsisnn/test_extract_constants.py create mode 100644 tests/python/contrib/test_cmsisnn/test_generate_constants.py diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 038cd0d04ff0..b69fc05ed942 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -174,11 +174,14 @@ def _build_function_memory_map(function_metadata): device_max_workspace = dict() main_func_metadata = function_metadata[MAIN_FUNC_NAME_STR] num_targets = len(main_func_metadata.workspace_sizes.items()) + from tvm.driver import tvmc # pylint: disable=import-outside-toplevel + + external_codegens = tvmc.composite_target.get_codegen_names() func_entries = [] target_local_entries = dict() for i in range(num_targets): - target = main_func_metadata.workspace_sizes.items()[i][0] - device_max_workspace[target] = 0 + main_target = main_func_metadata.workspace_sizes.items()[i][0] + device_max_workspace[main_target] = 0 for func_name, finfo in function_metadata.items(): if func_name == MAIN_FUNC_NAME_STR: continue @@ -201,8 +204,11 @@ def _build_function_memory_map(function_metadata): "workspace_size_bytes": int(workspace_size), } target_local_entries[func_name].append(target_entry) - if workspace_size > device_max_workspace[target]: + if workspace_size > device_max_workspace.get(target, 0): device_max_workspace[target] = workspace_size + # TODO(Mousius) - Remove this massive hack when Targets are unified + if target.kind.name in external_codegens: + device_max_workspace[main_target] += int(workspace_size) for func_name, target_entries_ in target_local_entries.items(): func_entry = { diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index 824343e0066b..34efb1d7a162 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -24,6 +24,8 @@ from ...dataflow_pattern import is_constant, is_op, wildcard from .register import register_pattern_table +tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__) + def enabled(): return "cmsis-nn" in Target.list_kinds() @@ -53,11 +55,12 @@ def partition_for_cmsisnn(mod, params=None, **opts): transform.InferType(), transform.MergeComposite(pattern_table()), transform.AnnotateTarget("cmsis-nn"), - transform.MergeCompilerRegions(), transform.PartitionGraph(), + GenerateCMSISNNConstants(), + ExtractConstantsFromPartitionedFunction(), + transform.InferType(), ] ) - return seq(mod) @@ -65,25 +68,72 @@ def partition_for_cmsisnn(mod, params=None, **opts): def pattern_table(): """Get the CMSIS-NN compiler pattern table.""" - def softmax_pattern(): + def qnn_softmax_pattern(): + """Create pattern for quantized softmax""" pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) pattern = is_op("nn.softmax")(pattern) pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant()) return pattern - def check_quantized_softmax(extract): + def check_qnn_softmax(pattern): """Check if softmax is supported by CMSIS-NN.""" - dequantize_call = extract.args[0].args[0] - scale = extract.args[1].data.numpy().item(0) - zero_point = extract.args[2].data.numpy().item(0) + dequantize_call = pattern.args[0].args[0] + scale = pattern.args[1].data.numpy().item(0) + zero_point = pattern.args[2].data.numpy().item(0) # check for dtypes of quantize and dequantize return ( (scale == 1.0 / 256 and zero_point == -128) - and extract.attrs.out_dtype == "int8" + and pattern.attrs.out_dtype == "int8" and dequantize_call.args[0].checked_type.dtype == "int8" ) + def qnn_conv2d_pattern(): + """Create pattern for qnn.conv2D with optional fused relu.""" + qnn_conv2d = is_op("qnn.conv2d")( + wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant() + ) + bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant()) + req = is_op("qnn.requantize")( + qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant() + ) + clip_or_req = req.optional(is_op("clip")) + return clip_or_req + + def check_qnn_conv2d(pattern): + """Check if the Conv2D is supported by CMSIS-NN.""" + if str(pattern.op.name) == "clip": + relu = pattern + requantize = relu.args[0] + else: + requantize = pattern + requantize_input = requantize.args[0] + bias_add = None + bias_dtype = "int32" + if str(requantize_input.op.name) == "nn.bias_add": + bias_add = requantize_input + conv2d = bias_add.args[0] + bias_dtype = bias_add.args[1].checked_type.dtype + else: + conv2d = requantize_input + conv2d_input = conv2d.args[0] + conv2d_weight = conv2d.args[1] + + # kernel zero_point should be 0 + kernel_zp = conv2d.args[3].data.numpy() + kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp + + return ( + conv2d.attrs.out_dtype == "int32" + and conv2d.attrs.padding[2] == 0 + and conv2d.attrs.padding[3] == 0 + and conv2d_input.checked_type.dtype == "int8" + and conv2d_weight.checked_type.dtype == "int8" + and pattern.checked_type.dtype == "int8" + and bias_dtype == "int32" + and all([zp == 0 for zp in kernel_zp]) + ) + def binary_op_pattern(op): """Matches QNN binary operation""" return is_op(f"qnn.{op}")( @@ -97,7 +147,7 @@ def binary_op_pattern(op): is_constant(), ) - def check_quantized_binary_op(extract): + def check_qnn_binary_op(extract): """Check if multiply is supported by CMSIS-NN.""" return ( extract.args[0].checked_type.dtype == "int8" @@ -105,15 +155,8 @@ def check_quantized_binary_op(extract): ) return [ - ("cmsis-nn.quantized_softmax", softmax_pattern(), check_quantized_softmax), - ( - "cmsis-nn.quantized_mul", - binary_op_pattern("mul"), - check_quantized_binary_op, - ), - ( - "cmsis-nn.quantized_add", - binary_op_pattern("add"), - check_quantized_binary_op, - ), + ("cmsis-nn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax), + ("cmsis-nn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d), + ("cmsis-nn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op), + ("cmsis-nn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op), ] diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index fde1de061bfa..9e2eb8dd527d 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -698,25 +698,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} LoweredOutput Codegen(relay::Function func, String mod_name) { - AOTOnDemandAllocator initial_aot_allocator; - initial_aot_allocator.Run(func); - - // Pre-lowering storage map and memory plan - // TODO(mbs): Why plan memory and update workspace sizes before lowering? - StorageMap initial_storage_map = initial_aot_allocator.GetStorageMap(); - StaticMemoryPlan memory_plan(initial_storage_map); - IRModule mod = IRModule::FromExpr(func); - - backend::FunctionInfo func_info; - - if (memory_plan.defined()) { - // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize - func_info = tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan->expr_to_storage_info); - mod = WithAttr(mod, "main_func_info", func_info); - } - - IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](Function func) { + IRModule lowered_mod = tec::LowerTEPass(mod_name, [this](BaseFunc func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. @@ -733,12 +716,17 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); - // Post-lowering storage map for writing main func - this should be the same map as previously - // created, just referencing the new expressions created from lowering + // Post-lowering storage map for writing main func AOTOnDemandAllocator final_aot_allocator; final_aot_allocator.Run(lowered_main_func); storage_device_map_ = final_aot_allocator.GetStorageMap(); + // TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize + StaticMemoryPlan memory_plan(storage_device_map_); + backend::FunctionInfo func_info = + tec::UpdateMainWorkspaceSize(lowered_mod, targets_, memory_plan->expr_to_storage_info); + lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info); + for (auto input : lowered_main_func->params) { input_vars_.push_back(input); main_signature_.push_back(tir::Var("input", DataType::Handle())); diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc new file mode 100644 index 000000000000..5ed23ad1ad6a --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -0,0 +1,168 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file extract_constant.cc + * \brief Pushes out constants within partitioned functions all the way upto main() + */ + +#include +#include +#include +#include + +#include "../../../qnn/utils.h" +#include "../../../transforms/pattern_utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { + +/*! + * \brief This Mutator finds all functions with constants. Constants are replaced with function + * parameter variables. Constants are pushed all the way upto main(). + */ +class ExtractConstantsMutator : public MixedModeMutator { + public: + explicit ExtractConstantsMutator(const IRModule& mod) : mod_(mod) {} + + private: + String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); } + + Expr VisitExpr_(const FunctionNode* function) final { + Function func = GetRef(function); + function_to_constants_.Set(func, Array{}); + functions_.push_back(func); + auto new_body = VisitExpr(func->body); + functions_.pop_back(); + if (function_to_constants_[func].size()) { + func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), + func->attrs); + } + return func; + } + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + Expr final_call = post; + auto* post_call = post.as(); + + // Replace Constant arguments with Vars for ML Operators + // Perform this for non-main Call Nodes only + if (!functions_.empty() && call->op.as()) { + Array new_args; + for (auto& arg : post_call->args) { + auto* const_arg = arg.as(); + if (const_arg && !const_arg->is_scalar()) { + Var var_arg = Var(gen_var_name(), const_arg->tensor_type()); + new_args.push_back(var_arg); + const Function& last_func = functions_.back(); + Array fconstants(function_to_constants_[last_func]); + fconstants.push_back(GetRef(const_arg)); + function_to_constants_.Set(last_func, fconstants); + } else { + new_args.push_back(arg); + } + } + final_call = Call(call->op, new_args, call->attrs, {}); + } + + // Since the constants are kicked out of partitioned functions + // a new call to global function is needed + if (auto* glob_var_node = post_call->op.as()) { + auto glob_var = GetRef(glob_var_node); + auto glob_func = Downcast(mod_->Lookup(glob_var)); + auto new_glob_func = VisitExpr(glob_func); + if (!new_glob_func.same_as(glob_func)) { + mod_->Update(glob_var, Downcast(new_glob_func)); + Array new_args = post_call->args; + ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end()); + for (auto constant : function_to_constants_.at(glob_func)) { + new_args.push_back(constant); + } + final_call = Call(glob_var, new_args); + } + } + + // Since the constants are kicked out of the local partitioned functions + // a new call to local function is needed + // Also, pass on the constants to the callee of this function to support nested functions + if (auto* func_node = call->op.as()) { + Function func = GetRef(func_node); + auto new_func = VisitExpr(func); + if (!new_func.same_as(func)) { + Array new_args = post_call->args; + ICHECK(function_to_constants_.find(func) != function_to_constants_.end()); + const Function& last_func = functions_.back(); + Array fconstants(function_to_constants_[last_func]); + for (auto constant : function_to_constants_.at(func)) { + fconstants.push_back(constant); + Var var_arg = Var(gen_var_name(), constant->tensor_type()); + new_args.push_back(var_arg); + } + function_to_constants_.Set(last_func, fconstants); + final_call = Call(new_func, new_args); + } + } + + return final_call; + } + + private: + /* \brief Updated module where all calls have replaced constants with new variables */ + IRModule mod_; + /* \brief Maintains mapping of original function to the replaced constants */ + Map> function_to_constants_; + /* \brief Stack of functions to determine scope while filling up function_to_constants_ */ + Array functions_; + /* \brief Keeps track of variables being created */ + int var_count_ = 0; +}; + +/*! * \brief Kicks out all constants out of the partitioned function into main() */ +IRModule ExtractConstants(const IRModule& mod) { + String func_name; + Function func; + + auto extract_constants = ExtractConstantsMutator(mod); + Function main_func = Downcast(mod->Lookup("main")); + auto new_main_body = extract_constants.VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + auto main_var = mod->GetGlobalVar("main"); + auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, + main_func->type_params, main_func->attrs); + mod->Update(main_var, new_main_func); + } + return mod; +} + +transform::Pass ExtractConstantsFromPartitionedFunction() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); }; + return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction", + {}); +} + +TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction") + .set_body_typed(ExtractConstantsFromPartitionedFunction); + +} // namespace cmsisnn +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc new file mode 100644 index 000000000000..0231e8b52117 --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -0,0 +1,244 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file generate_constant.cc + * \brief Generates quantization parameters needed by CMSIS-NN + */ + +#include +#include +#include +#include +#include + +#include "../../../op/make_op.h" +#include "../../../qnn/utils.h" +#include "../../../transforms/pattern_utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { + +/*! + * \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D. + * It will substitute original Conv2D's weight zero point and original Requantize's input zero point + * with CMSIS-NN's quantization parameters. + * https://github.com/tensorflow/tflite-micro/blob/0f40100fc60276e9f345c23282de3baf19a78059/tensorflow/lite/kernels/internal/quantization_util.cc#L53 + */ +class GenerateConstantsMutator : public MixedModeMutator { + public: + explicit GenerateConstantsMutator(const IRModule& mod) : mod_(mod) {} + + private: + /*! * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */ + Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) { + auto attrs = make_object(); + attrs->strides = std::move(conv2d_attrs->strides); + attrs->padding = std::move(conv2d_attrs->padding); + attrs->dilation = std::move(conv2d_attrs->dilation); + attrs->groups = conv2d_attrs->groups; + attrs->channels = std::move(conv2d_attrs->channels); + attrs->kernel_size = std::move(conv2d_attrs->kernel_size); + attrs->data_layout = std::move(conv2d_attrs->data_layout); + attrs->kernel_layout = runtime::String("OHWI"); + attrs->out_layout = std::move(conv2d_attrs->out_layout); + attrs->out_dtype = std::move(conv2d_attrs->out_dtype); + *new_attrs = tvm::Attrs{attrs}; + + std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); + int pos_o = kernel_layout.find("O"); + int pos_h = kernel_layout.find("H"); + int pos_w = kernel_layout.find("W"); + int pos_i = kernel_layout.find("I"); + + IRModule kernel_module; + auto func_body = MakeTranspose( + kernel_expr, {Integer(pos_o), Integer(pos_h), Integer(pos_w), Integer(pos_i)}); + auto kernel_func = + Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module)); + GlobalVar kernel_var("main"); + kernel_module->Add(kernel_var, kernel_func); + kernel_module = relay::transform::FoldConstant()(kernel_module); + kernel_func = Downcast(kernel_module->Lookup("main")); + return kernel_func->body; + } + + /*! * \brief Performs weight transpose and substitutes existing constants in the composite + * function for Conv2D with CMSIS-NN Requantize constants */ + Expr GenerateConv2dRequantConstants(const Expr& expr) { + const CallNode* clip_call = nullptr; + const CallNode* requantize_call = nullptr; + const CallNode* bias_add_call = nullptr; + const CallNode* conv2d_call = nullptr; + auto* final_call = expr.as(); + auto* final_op = final_call->op.as(); + if (final_op->name == "clip") { + clip_call = final_call; + requantize_call = clip_call->args[0].as(); + } else { + requantize_call = final_call; + } + auto* requantize_input = requantize_call->args[0].as(); + auto* requantize_input_op = requantize_input->op.as(); + if (requantize_input_op->name == "nn.bias_add") { + bias_add_call = requantize_input; + conv2d_call = bias_add_call->args[0].as(); + } else { + conv2d_call = requantize_input; + } + + // Transpose weights: HWIO -> OHWI + auto* conv2d_attrs = conv2d_call->attrs.as(); + tvm::Attrs new_conv2d_attrs; + Expr transposed_kernel = + ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs); + + // Obtain input and output scales from Relay's Requantization + int64_t out_channels = conv2d_attrs->channels.as()->value; + float output_scale = GetScalarFromConstant(requantize_call->args[3]); + auto input_scales = tvm::relay::qnn::GetFloatVectorFromConstant(requantize_call->args[1]); + ICHECK(input_scales.size() == static_cast(out_channels)); + + // Calculate requantization multiplier and shift + Device dev{DLDeviceType::kDLCPU, 0}; + runtime::NDArray multiplier_nda = + runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev); + runtime::NDArray shift_nda = runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev); + int32_t* multiplier = static_cast(multiplier_nda->data); + int32_t* shift = static_cast(shift_nda->data); + for (int i = 0; i < out_channels; ++i) { + double effective_output_scale = + static_cast(input_scales[i]) / static_cast(output_scale); + std::tie(*(multiplier + i), *(shift + i)) = + tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale); + } + + // Create constants from requantization multiplier and shift + Constant multiplier_const(multiplier_nda); + Constant shift_const(shift_nda); + + // Convert scale scalars into Constants + // Scales are expected as Constants by following passes + Expr weight_scale = conv2d_call->args[5]; + Expr req_inp_scale = requantize_call->args[1]; + if (out_channels == 1) { + runtime::NDArray weight_scale_nda = + runtime::NDArray::Empty({out_channels}, DataType::Float(32), dev); + float* weight_scale_p = static_cast(weight_scale_nda->data); + *weight_scale_p = GetScalarFromConstant(weight_scale); + weight_scale = Constant(weight_scale_nda); + + runtime::NDArray req_inp_scale_nda = + runtime::NDArray::Empty({out_channels}, DataType::Float(32), dev); + float* req_inp_scale_p = static_cast(req_inp_scale_nda->data); + *req_inp_scale_p = GetScalarFromConstant(req_inp_scale); + req_inp_scale = Constant(req_inp_scale_nda); + } + + // Replace existing weights (HWIO) with the transposed ones (OHWI) + // Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier + // Substitute Requantize input_zero_point with CMSIS-NN shift + // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc + Array conv2d_args = {conv2d_call->args[0], transposed_kernel, conv2d_call->args[2], + multiplier_const, conv2d_call->args[4], weight_scale}; + Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}); + if (bias_add_call) { + ret_call = + Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, {}); + } + Array requantize_args = {ret_call, req_inp_scale, shift_const, requantize_call->args[3], + requantize_call->args[4]}; + ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {}); + if (clip_call) { + ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}); + } + return ret_call; + } + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + Expr final_call = post; + auto* post_call = post.as(); + + auto* global_var = call->op.as(); + if (global_var) { + // Update to global function call needed because the body changes while + // generating new constants + Function func = Downcast(mod_->Lookup(global_var->name_hint)); + Expr new_body = VisitExpr(func->body); + if (!new_body.same_as(func->body)) { + Function new_func = Function(FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); + mod_->Update(GetRef(global_var), new_func); + final_call = Call(GetRef(global_var), post_call->args); + } + } + + // Recreate composite function and corresponding call + // Updated composite function contains CMSIS-NN quantized multiplier and shift constants + if (call->op.as()) { + auto* func = call->op.as(); + auto func_name = func->GetAttr(attr::kComposite); + if (func_name.defined() && func_name == "cmsis-nn.qnn_conv2d") { + Expr new_body = GenerateConv2dRequantConstants(func->body); + Function new_func = Function(FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); + final_call = Call(new_func, post_call->args); + } + } + + return final_call; + } + + private: + IRModule mod_; +}; + +IRModule GenerateConstants(const IRModule& mod) { + String func_name; + Function func; + + // Introduces CMSIS-NN constants before the call to the external Relay function + auto generate_constants = GenerateConstantsMutator(mod); + Function main_func = Downcast(mod->Lookup("main")); + auto new_main_body = generate_constants.VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + auto main_var = mod->GetGlobalVar("main"); + auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, + main_func->type_params, main_func->attrs); + mod->Update(main_var, new_main_func); + } + + return mod; +} + +transform::Pass GenerateCMSISNNConstants() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, transform::PassContext pc) { return GenerateConstants(m); }; + return tvm::transform::CreateModulePass(pass_func, 0, "GenerateCMSISNNConstants", {}); +} + +TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GenerateCMSISNNConstants") + .set_body_typed(GenerateCMSISNNConstants); + +} // namespace cmsisnn +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index bd0ac52330d5..1b639dd36e9d 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -18,6 +18,7 @@ * under the License. */ #include +#include #include #include #include @@ -37,7 +38,9 @@ namespace cmsisnn { class RelayToTIRVisitor : public MixedModeMutator { public: explicit RelayToTIRVisitor(IRModule ir_module, Target target) - : ir_module_(ir_module), target_(target) {} + : ir_module_(ir_module), target_(target) { + context_buffer_id_ = 0; + } IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); @@ -58,7 +61,9 @@ class RelayToTIRVisitor : public MixedModeMutator { inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); } void CreatePrimFuncForExtern(const GlobalVar& global_var, Array func_signature, - tvm::Array call_extern_args) { + tvm::Array call_extern_args, + std::string context_buffer_name = "NULL", + int context_buffer_size = 0) { Map dict_attrs; dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint); dict_attrs.Set(tvm::attr::kTarget, target_); @@ -67,16 +72,153 @@ class RelayToTIRVisitor : public MixedModeMutator { tir::Stmt body = tir::Evaluate( tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args)); + if (context_buffer_size) { + tir::Var buffer_var(context_buffer_name, + PointerType(PrimType(DataType::Int(8)), "global.workspace")); + body = tir::Allocate(buffer_var, DataType::Int(8), {context_buffer_size}, tir::const_true(), + body); + body = + tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, target_->kind->device_type, body); + body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body); + } + tir::PrimFunc replacement_func(func_signature, body, VoidType(), Map(), DictAttrs(dict_attrs)); ir_module_->Add(global_var, replacement_func); } + Array CMSISNNDimensions(const Array& shape) { + ICHECK(shape.size() == 4) << "Supports only CMSIS-NN shapes of dimension 4."; + return Array{ToArg(qnn::get_const_int(shape[0])), ToArg(qnn::get_const_int(shape[1])), + ToArg(qnn::get_const_int(shape[2])), + ToArg(qnn::get_const_int(shape[3]))}; + } + + void EmitConv2D(const GlobalVar& global_var, const Expr& expr) { + const CallNode* clip_call = nullptr; + const CallNode* requantize_call = nullptr; + const CallNode* bias_add_call = nullptr; + const CallNode* conv2d_call = nullptr; + const CallNode* final_call = expr.as(); + const OpNode* final_op = final_call->op.as(); + if (final_op->name == "clip") { + clip_call = final_call; + requantize_call = clip_call->args[0].as(); + } else { + requantize_call = final_call; + } + const CallNode* requantize_input = requantize_call->args[0].as(); + const OpNode* requantize_input_op = requantize_input->op.as(); + if (requantize_input_op->name == "nn.bias_add") { + bias_add_call = requantize_input; + conv2d_call = bias_add_call->args[0].as(); + } else { + conv2d_call = requantize_input; + } + + // TIR variables are created in the order they appear in the Relay partitioned function + // %1 = qnn.conv2d(%input, %weight_const_0, input_zero_point_scalar, + // %cmsisnn_multiplier_const_1, %input_scale_scalar, %weight_scale_const_2) + // %2 = nn.bias_add(%1, %bias_const_3, axis=3) + // %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5, + // %output_scale_scalar, %output_zero_point_scalar) + // clip(%3, a_min=%min_scalar, a_max=%max_scalar) + tir::Var input("input", DataType::Handle(8)); + tir::Var filter("filter", DataType::Handle(8)); + tir::Var multiplier("multiplier", DataType::Handle(32)); + tir::Var filter_scale("filter_scale", DataType::Handle(32)); + tir::Var bias("bias", DataType::Handle(32)); + tir::Var input_scale("input_scale", DataType::Handle(32)); + tir::Var shift("shift", DataType::Handle(32)); + tir::Var output("output", DataType::Handle(8)); + + // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern + // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50 + + // prepare cmsis_nn_conv_params + const Conv2DAttrs* conv2d_attrs = conv2d_call->attrs.as(); + int32_t input_offset = -GetScalarFromConstant(conv2d_call->args[2]); + int32_t output_offset = GetScalarFromConstant(requantize_call->args[4]); + int32_t stride_w = qnn::get_const_int(conv2d_attrs->strides[1]); + int32_t stride_h = qnn::get_const_int(conv2d_attrs->strides[0]); + int32_t padding_w = qnn::get_const_int(conv2d_attrs->padding[1]); + int32_t padding_h = qnn::get_const_int(conv2d_attrs->padding[0]); + int32_t dilation_w = qnn::get_const_int(conv2d_attrs->dilation[1]); + int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]); + int32_t clip_min, clip_max; + if (clip_call) { + const ClipAttrs* clip_attrs = clip_call->attrs.as(); + clip_min = clip_attrs->a_min; + clip_max = clip_attrs->a_max; + } else { + clip_min = -128; + clip_max = 127; + } + + tvm::Array call_ext_args = {tir::StringImm("arm_convolve_wrapper_s8"), input, filter, + multiplier}; + if (bias_add_call) { + call_ext_args.push_back(bias); + } + call_ext_args.push_back(shift); + call_ext_args.push_back(output); + + tvm::Array scalar_args = {ToArg(input_offset), ToArg(output_offset), ToArg(stride_w), + ToArg(stride_h), ToArg(padding_w), ToArg(padding_h), + ToArg(dilation_w), ToArg(dilation_h), ToArg(clip_min), + ToArg(clip_max)}; + + // cmsis_nn_dims *input_dims (NHWC) + Array input_shape = conv2d_call->args[0]->type_as()->shape; + Array input_dims = CMSISNNDimensions(input_shape); + + // cmsis_nn_dims *filter_dims (OHWI) + Array filter_shape = conv2d_call->args[1]->type_as()->shape; + Array filter_dims = CMSISNNDimensions(filter_shape); + + // cmsis_nn_dims *bias_dims (1,1,1,output_channels) + Array bias_shape{1, 1, 1, filter_shape[0]}; + Array bias_dims = CMSISNNDimensions(bias_shape); + + // cmsis_nn_dims *output_dims (NHWC) + Array output_shape = conv2d_call->type_as()->shape; + Array output_dims = CMSISNNDimensions(output_shape); + + // https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367 + std::string context_buffer_name = "NULL"; + size_t context_buffer_size = + (2 * qnn::get_const_int(input_shape[3]) * qnn::get_const_int(filter_shape[2]) * + qnn::get_const_int(filter_shape[1]) * (int32_t)sizeof(int16_t)); + if (context_buffer_size) { + context_buffer_name = "context_buffer_" + std::to_string(context_buffer_id_++); + } + tvm::Array context_buffer_args = {tir::StringImm(context_buffer_name), + ToArg(context_buffer_size)}; + + scalar_args = tvm::runtime::Concat(context_buffer_args, scalar_args); + scalar_args = tvm::runtime::Concat(scalar_args, input_dims); + scalar_args = tvm::runtime::Concat(scalar_args, filter_dims); + scalar_args = tvm::runtime::Concat(scalar_args, bias_dims); + scalar_args = tvm::runtime::Concat(scalar_args, output_dims); + call_ext_args = tvm::runtime::Concat(call_ext_args, scalar_args); + + Array func_signature{input, filter, multiplier, filter_scale}; + if (bias_add_call) { + func_signature.push_back(bias); + } + func_signature.push_back(input_scale); + func_signature.push_back(shift); + func_signature.push_back(output); + + CreatePrimFuncForExtern(global_var, func_signature, call_ext_args, context_buffer_name, + context_buffer_size); + } + void EmitSoftMax(const GlobalVar& global_var, const Expr& expr) { - auto* quantize_call = expr.as(); - auto* softmax_call = quantize_call->args[0].as(); - auto* dequant_call = softmax_call->args[0].as(); + const CallNode* quantize_call = expr.as(); + const CallNode* softmax_call = quantize_call->args[0].as(); + const CallNode* dequant_call = softmax_call->args[0].as(); const float quant_scale = GetScalarFromConstant(dequant_call->args[1]); // assuming layout as NHWC @@ -91,16 +233,15 @@ class RelayToTIRVisitor : public MixedModeMutator { // calculate multiplier and shift for CMSIS-NN softmax API // Note: TensorFlow Lite Micro assumptions // Output zero point and scale are fixed to -128 and 1 / 256 + // kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47 - double beta = 1.0; - int32_t input_bits = 5; - double beta_multiplier = (beta * quant_scale * (1 << (31 - input_bits))); + double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits))); beta_multiplier = std::min(beta_multiplier, (1ll << 31) - 1.0); auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier); int32_t mult = std::get<0>(mult_shift_pair); int32_t shift = std::get<1>(mult_shift_pair); - int32_t diff_min = (1 << 5) - 1; - diff_min <<= (31 - 5); + int32_t diff_min = (1 << kScaledDiffIntegerBits) - 1; + diff_min <<= (31 - kScaledDiffIntegerBits); diff_min >>= shift; diff_min *= -1; @@ -250,15 +391,18 @@ class RelayToTIRVisitor : public MixedModeMutator { GlobalVar new_global_var(func_name.value()); new_global_var->checked_type_ = composite_func->checked_type(); - if (comp_name == "cmsis-nn.quantized_softmax") { + if (comp_name == "cmsis-nn.qnn_softmax") { EmitSoftMax(new_global_var, composite_func->body); } - if (comp_name == "cmsis-nn.quantized_mul") { + if (comp_name == "cmsis-nn.qnn_mul") { EmitMul(new_global_var, composite_func->body); } - if (comp_name == "cmsis-nn.quantized_add") { + if (comp_name == "cmsis-nn.qnn_add") { EmitAdd(new_global_var, composite_func->body); } + if (comp_name == "cmsis-nn.qnn_conv2d") { + EmitConv2D(new_global_var, composite_func->body); + } Array args; for (const auto& arg : call->args) { @@ -273,6 +417,10 @@ class RelayToTIRVisitor : public MixedModeMutator { } private: + static constexpr int32_t kScaledDiffIntegerBits = 5; + static constexpr int32_t kInputBits = 5; + static constexpr double kBeta = 1.0; + int32_t context_buffer_id_; IRModule ir_module_; Target target_; }; diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 7350107d186c..b243af6c4d5f 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -41,6 +41,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; + decl_stream << "#include \n"; CodeGenCHost::Init(output_ssa, emit_asserts, target_str); } @@ -50,6 +51,169 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { * \return string of code that offloads a subgraph to the Cortex-M */ void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } + + private: + /*! * \brief Emit the CMSIS-NN context buffer */ + void VisitStmt_(const AllocateNode* op) { + context_buffer_name_ = op->buffer_var->name_hint; + context_buffer_size_ = op->constant_allocation_size(); + CodeGenC::VisitStmt_(op); + } + + /*! * \brief Emits CMSIS-NN APIs for every call_extern */ + void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + if (!op->op.same_as(builtin::call_extern())) { + CodeGenCHost::VisitExpr_(op, os); + return; + } + std::string cmsis_func_name = op->args[0].as()->value; + if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" || + cmsis_func_name == "arm_elementwise_add_s8") { + CodeGenC::VisitExpr_(op, os); + } else if (cmsis_func_name == "arm_convolve_wrapper_s8") { + EmitConv2D(op); + } + return; + } + + /*! * \brief Emits cmsis_nn_context struct */ + std::string EmitCMSISNNContext(std::ostream& os, std::string buf_name, int buf_size) { + std::string struct_name = "context"; + PrintIndent(); + os << "cmsis_nn_context " << struct_name << "= {" << buf_name << "," << buf_size << "};\n"; + return struct_name; + } + + /*! * \brief Emits cmsis_nn_conv_params struct */ + std::string EmitCMSISNNConvParams(std::ostream& os, int32_t input_offset, int32_t output_offset, + int32_t stride_w, int32_t stride_h, int32_t padding_w, + int32_t padding_h, int32_t dilation_w, int32_t dilation_h, + int32_t clip_min, int32_t clip_max) { + std::string struct_name = "conv_params"; + PrintIndent(); + os << "cmsis_nn_tile stride = {" << stride_w << "," << stride_h << "};\n"; + PrintIndent(); + os << "cmsis_nn_tile padding = {" << padding_w << "," << padding_h << "};\n"; + PrintIndent(); + os << "cmsis_nn_tile dilation = {" << dilation_w << "," << dilation_h << "};\n"; + PrintIndent(); + os << "cmsis_nn_activation activation = {" << clip_min << "," << clip_max << "};\n"; + PrintIndent(); + os << "cmsis_nn_conv_params " << struct_name << " = {" << input_offset << ", " << output_offset + << ", stride, padding, dilation, activation};\n"; + return struct_name; + } + + /*! * \brief Emits cmsis_nn_per_channel_quant_params struct */ + std::string EmitCMSISNNPerChannelQuantParams(std::ostream& os, std::string multiplier, + std::string shift) { + std::string struct_name = "quant_params"; + PrintIndent(); + os << "cmsis_nn_per_channel_quant_params " << struct_name << " = {" << multiplier << ", " + << shift << "};\n"; + return struct_name; + } + + /*! * \brief Emits cmsis_nn_dims struct */ + std::string EmitCMSISNNDims(std::ostream& os, std::string tensor_type, int32_t n, int32_t h, + int32_t w, int32_t c) { + std::string struct_name = tensor_type + "_dims"; + PrintIndent(); + os << "cmsis_nn_dims " << struct_name << " = {" << n << "," << h << "," << w << "," << c + << "};\n"; + return struct_name; + } + + /*! * \brief Emits CMSIS-NN APIs for every call_extern */ + void EmitConv2D(const CallNode* op) { + static const int max_num_args = 35; + std::string cmsis_func_name = op->args[0].as()->value; + + bool bias_enabled = false; + if (op->args.size() == max_num_args) { + bias_enabled = true; + } + + auto get_var_name = [](const CallNode* op, int id) { + return op->args[id].as()->name_hint.c_str(); + }; + auto get_arg_value = [](const CallNode* op, int id) { + return op->args[id].as()->value; + }; + int arg_id = 0; + std::string input_data = get_var_name(op, ++arg_id); + std::string filter_data = get_var_name(op, ++arg_id); + std::string multiplier = get_var_name(op, ++arg_id); + std::string bias_data("0x0"); + if (bias_enabled) { + bias_data = get_var_name(op, ++arg_id); + } + std::string shift = get_var_name(op, ++arg_id); + std::string output_data = get_var_name(op, ++arg_id); + + std::string context_buffer_name = op->args[++arg_id].as()->value; + int context_buffer_size = get_arg_value(op, ++arg_id); + int input_offset = get_arg_value(op, ++arg_id); + int output_offset = get_arg_value(op, ++arg_id); + int stride_w = get_arg_value(op, ++arg_id); + int stride_h = get_arg_value(op, ++arg_id); + int padding_w = get_arg_value(op, ++arg_id); + int padding_h = get_arg_value(op, ++arg_id); + int dilation_w = get_arg_value(op, ++arg_id); + int dilation_h = get_arg_value(op, ++arg_id); + int clip_min = get_arg_value(op, ++arg_id); + int clip_max = get_arg_value(op, ++arg_id); + int input_n = get_arg_value(op, ++arg_id); + int input_h = get_arg_value(op, ++arg_id); + int input_w = get_arg_value(op, ++arg_id); + int input_c = get_arg_value(op, ++arg_id); + int filter_n = get_arg_value(op, ++arg_id); + int filter_h = get_arg_value(op, ++arg_id); + int filter_w = get_arg_value(op, ++arg_id); + int filter_c = get_arg_value(op, ++arg_id); + int bias_n = get_arg_value(op, ++arg_id); + int bias_h = get_arg_value(op, ++arg_id); + int bias_w = get_arg_value(op, ++arg_id); + int bias_c = get_arg_value(op, ++arg_id); + int output_n = get_arg_value(op, ++arg_id); + int output_h = get_arg_value(op, ++arg_id); + int output_w = get_arg_value(op, ++arg_id); + int output_c = get_arg_value(op, ++arg_id); + + std::string context = EmitCMSISNNContext(stream, context_buffer_name, context_buffer_size); + std::string conv_params = + EmitCMSISNNConvParams(stream, input_offset, output_offset, stride_w, stride_h, padding_w, + padding_h, dilation_w, dilation_h, clip_min, clip_max); + std::string quant_params = EmitCMSISNNPerChannelQuantParams(stream, multiplier, shift); + std::string input_dim = EmitCMSISNNDims(stream, "input", input_n, input_h, input_w, input_c); + std::string filter_dim = + EmitCMSISNNDims(stream, "filter", filter_n, filter_h, filter_w, filter_c); + std::string bias_dim = EmitCMSISNNDims(stream, "bias", bias_n, bias_h, bias_w, bias_c); + std::string output_dim = + EmitCMSISNNDims(stream, "output", output_n, output_h, output_w, output_c); + + PrintIndent(); + stream << "arm_status status = "; + stream << cmsis_func_name << "("; + stream << "&" << context << ", "; + stream << "&" << conv_params << ", "; + stream << "&" << quant_params << ", "; + stream << "&" << input_dim << ", " << input_data << ", "; + stream << "&" << filter_dim << ", " << filter_data << ", "; + stream << "&" << bias_dim << ", " << bias_data << ", "; + stream << "&" << output_dim << ", " << output_data << ");\n"; + PrintIndent(); + stream << "if (status != ARM_MATH_SUCCESS) {\n"; + PrintIndent(); + PrintIndent(); + stream << "return -1;\n"; + PrintIndent(); + stream << "}\n"; + } + + private: + std::string context_buffer_name_ = "NULL"; + int context_buffer_size_ = 0; }; runtime::Module TIRToRuntime(IRModule mod, Target target) { diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 1456f7e68ed3..051b325e3db3 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -225,7 +225,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 328e970ab093..85a03cb5bd16 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -615,6 +615,20 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Already lowered by other means so we don't need to mutate // the call but we do need to mutate the arguments if (prim_func->IsInstance()) { + // Function should already be Target annotated by this point + // but the TE Compiler metadata is still needed for the callback + // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks + GlobalVar prim_func_var = Downcast(call_node->op); + tir::PrimFunc downcast_prim_func = Downcast(prim_func); + + Map prim_fns = {{prim_func_var, downcast_prim_func}}; + tir::PrimFunc func_with_metadata = + WithAttrs(downcast_prim_func, { + {"prim_fn_var", prim_func_var}, + {"prim_funcs", prim_fns}, + }); + + this->process_fn_(func_with_metadata); return Call(call_node->op, visited_args, call_node->attrs); } @@ -682,8 +696,7 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) { } } -Pass LowerTensorExpr(const String& module_name, TECompiler compiler, - std::function process_fn) { +Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler); @@ -831,13 +844,13 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa /*! * \brief A function to create the function metadata for an input function (ie calculate buffer * input/output sizes) - * \param relay_func The function to calculate function metadata for + * \param func The function to calculate function metadata for * \param function_metadata The map that stores all the function metadatas */ -void UpdateFunctionMetadata(Function relay_func, +void UpdateFunctionMetadata(BaseFunc func, Map& function_metadata) { // NOLINT(*) VLOG_CONTEXT << "UpdateFunctionMetadata"; - VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(relay_func); + VLOG(1) << "updating function metadata for:" << std::endl << PrettyPrint(func); // Originally UpdateFunctionMetadata took in CCachedFunc and looped through all the funcs stored // there Now the goal is to take only one func because process_fn should be controlling the // iteration However, to do the workspace calculations we need the primfuncs. So process_fn @@ -852,13 +865,13 @@ void UpdateFunctionMetadata(Function relay_func, Map relay_primfuncs; Optional> prim_fns = - relay_func->GetAttr>("prim_funcs"); + func->GetAttr>("prim_funcs"); CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler."; - Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); + Optional prim_fn_var = func->GetAttr("prim_fn_var"); CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; - Optional relay_target = relay_func->GetAttr(tvm::attr::kTarget); + Optional relay_target = func->GetAttr(tvm::attr::kTarget); CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; for (const auto& kv : prim_fns.value()) { @@ -883,6 +896,12 @@ void UpdateFunctionMetadata(Function relay_func, // Calculating size for I/O // TODO(mbs): See also the other three utils for calculating tensor bytesize. for (auto const& param : prim_fn->params) { + bool not_a_buffer = prim_fn->buffer_map.count(param) == 0; + if (not_a_buffer) { + io_sizes.Set(prim_fn_target, 0); + continue; + } + auto p_shape = prim_fn->buffer_map[param]->shape; int num_of_elements = 1; for (const auto& dim_index_expr : p_shape) { @@ -899,7 +918,9 @@ void UpdateFunctionMetadata(Function relay_func, constant_sizes.Set(prim_fn_target, 0); tir_primfuncs.Set(prim_fn_target, prim_fn); - relay_primfuncs.Set(prim_fn_target, relay_func); + if (func->IsInstance()) { + relay_primfuncs.Set(prim_fn_target, Downcast(func)); + } } backend::FunctionInfo fi = backend::FunctionInfo( @@ -913,8 +934,7 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -IRModule LowerTE(const IRModule& module, const String& module_name, - std::function process_fn) { +IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn) { TECompiler compiler; auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module); @@ -966,7 +986,7 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(const String& module_name, std::function process_fn) { +Pass LowerTEPass(const String& module_name, ProcessFn process_fn) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { return LowerTE(module, module_name, process_fn); }; diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index b5d5b508f6be..268d1a65a31b 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -63,7 +63,7 @@ namespace tec { using TargetMap = std::unordered_map; using DeviceMap = std::unordered_map; -using ProcessFn = std::function; +using ProcessFn = std::function; /*! * \brief A compiler which lowers primitive Relay functions to tensor expressions @@ -140,10 +140,10 @@ class TECompiler : public ObjectRef { /*! * \brief A function to create the function metadata for an input function (ie calculate buffer * input/output sizes) - * \param relay_func The function to calculate function metadata for + * \param func The function to calculate function metadata for * \param function_metadata The map that stores all the function metadatas */ -void UpdateFunctionMetadata(Function relay_func, +void UpdateFunctionMetadata(BaseFunc func, Map& function_metadata); // NOLINT(*) /*! @@ -188,7 +188,7 @@ Map GetPerTargetModules(IRModule mod); */ IRModule LowerTE( const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name, - ProcessFn process_fn = [](Function f) {}); + ProcessFn process_fn = [](BaseFunc f) {}); /*! \brief Pass to lower an IRModule's primitive functions to TIR. * @@ -201,7 +201,7 @@ IRModule LowerTE( * each function that we lower * \returns The pass which lowers primitive functions to TIR */ -transform::Pass LowerTEPass(const String& module_name, std::function process_fn); +transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn); } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 4224a99c2628..f22e9f4318e5 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -190,7 +190,7 @@ struct ConstantUpdater : public ExprVisitor { * \param func The function from which to get the constant params. * \param params The params to update with the constants. */ -inline void UpdateConstants(Function func, +inline void UpdateConstants(BaseFunc func, std::unordered_map* params) { VLOG_CONTEXT << "UpdateConstants"; VLOG(1) << "updating constants for:" << std::endl << PrettyPrint(func); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b6ecd7a4b7ca..490d6893964d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1151,8 +1151,7 @@ void VMCompiler::Codegen() { // Collect metadata in functions that are handled by external codegen. auto name = cfunc->prim_fn_var->name_hint; ICHECK(mod->ContainGlobalVar(name)); - Function func = Downcast(mod->Lookup(name)); - backend::UpdateConstants(func, ¶ms_); + backend::UpdateConstants(mod->Lookup(name), ¶ms_); } else if (funcs.count(target) == 0) { funcs.Set(target, mod); } else { diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 4ff1c6ef61ed..d72e2b37ee8a 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -54,7 +54,7 @@ class CodeGenCHost : public CodeGenC { // overload visitor functions void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os); // NOLINT(*) // overload min and max to use the ternary operator, so we don't rely on the // standard library implementations void VisitExpr_(const MinNode* op, std::ostream& os) final; // NOLINT(*) diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py index 42eb31a3532c..39b8c5fd21c6 100644 --- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py +++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py @@ -131,6 +131,7 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z ) +@skip_if_no_reference_system @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add]) @pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]]) diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py new file mode 100644 index 000000000000..243197e4eb3e --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -0,0 +1,316 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: Conv2D""" +import itertools +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + + +from tests.python.relay.aot.aot_test_utils import ( + AOTTestModel, + AOT_CORSTONE300_RUNNER, + AOT_DEFAULT_RUNNER, + generate_ref_data, + compile_and_run, +) +from utils import ( + skip_if_no_reference_system, + make_module, + count_num_calls, + get_range_for_dtype_str, + get_same_padding, + get_conv2d_qnn_params, + make_qnn_relu, +) + + +def make_model( + shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + kernel_dtype, + out_channels, + weight_format, + enable_bias, + relu_type, +): + """Return a model and any parameters it may have""" + h_index = weight_format.index("H") + w_index = weight_format.index("W") + kernel_h = kernel_shape[h_index] + kernel_w = kernel_shape[w_index] + a = relay.var("input", shape=shape, dtype=dtype) + p = (0, 0, 0, 0) + if padding == "SAME": + p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) + a = relay.nn.pad( + a, + pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)], + pad_value=input_zero_point, + pad_mode="constant", + ) + shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) + + weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) + rng = np.random.default_rng(12321) + w = tvm.nd.array( + rng.integers( + np.iinfo(kernel_dtype).min, + high=np.iinfo(kernel_dtype).max, + size=weight_shape, + dtype=kernel_dtype, + ) + ) + weight_const = relay.const(w, kernel_dtype) + conv = relay.qnn.op.conv2d( + a, + weight_const, + input_zero_point=relay.const(input_zero_point, "int32"), + kernel_zero_point=relay.const(kernel_zero_point, "int32"), + input_scale=relay.const(input_scale, "float32"), + kernel_scale=relay.const(kernel_scale, "float32"), + kernel_size=(kernel_h, kernel_w), + data_layout="NHWC", + kernel_layout=weight_format, + dilation=dilation, + strides=strides, + groups=groups, + channels=out_channels, + padding=p, + out_dtype="int32", + ) + b = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) + bias_const = relay.const(b, "int32") + last_op = relay.nn.bias_add(conv, bias_const, axis=3) if enable_bias else conv + requant_input_sc = [sc * input_scale for sc in kernel_scale] + last_op = relay.qnn.op.requantize( + last_op, + relay.const(requant_input_sc, "float32"), + relay.const(0, "int32"), + relay.const(output_scale, "float32"), + relay.const(output_zero_point, "int32"), + out_dtype=dtype, + ) + last_op = make_qnn_relu(last_op, relu_type, output_scale, output_zero_point, dtype) + params = {"w": w, "b": b} + return last_op, params + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)]) +@pytest.mark.parametrize("kernel_size", [(3, 3)]) +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1)), ((1, 1), (1, 1))]) +@pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize("relu_type", ["NONE", "RELU"]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale, out_channels", + [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], +) +def test_op_int8( + ifm_shape, + kernel_size, + padding, + strides, + dilation, + enable_bias, + relu_type, + input_zero_point, + input_scale, + kernel_scale, + out_channels, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + kernel_zero_point = 0 + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + dtype = "int8" + in_min, in_max = get_range_for_dtype_str(dtype) + + weight_shape = None + if weight_format == "HWIO": + weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + else: + weight_shape = (kernel_h, kernel_w, ifm_shape[3], out_channels) + + output_scale, output_zero_point = get_conv2d_qnn_params( + weight_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + False, + ) + + model, params = make_model( + ifm_shape, + weight_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsis-nn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +def parameterize_for_invalid_model(test): + in_dtype = ["uint8", "int8"] + kernel_dtype = ["uint8", "int8"] + kernel_zero_point = [-33, 10, 0] + all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point) + all_combinations = filter( + lambda parameters: not ( + parameters[0] == "int8" and parameters[1] == "int8" and parameters[2] == 0 + ), + all_combinations, + ) + return pytest.mark.parametrize( + ["in_dtype", "kernel_dtype", "kernel_zero_point"], + all_combinations, + )(test) + + +@tvm.testing.requires_cmsisnn +@parameterize_for_invalid_model +def test_invalid_parameters( + in_dtype, + kernel_dtype, + kernel_zero_point, +): + ifm_shape = (1, 28, 28, 12) + out_channels = 2 + input_scale = 1 + input_zero_point = 24 + kernel_scale = [0.11, 0.0237] + in_min, in_max = get_range_for_dtype_str(in_dtype) + + kernel_layout = "HWIO" + kernel_shape = [3, 3, ifm_shape[3], out_channels] + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + in_dtype, + kernel_dtype, + in_dtype, + False, + ) + model, params = make_model( + shape=ifm_shape, + kernel_shape=kernel_shape, + input_zero_point=input_zero_point, + input_scale=input_scale, + kernel_zero_point=kernel_zero_point, + kernel_scale=kernel_scale, + output_zero_point=output_zero_point, + output_scale=output_scale, + padding="SAME", + strides=(1, 1), + dilation=(1, 1), + groups=1, + dtype=in_dtype, + kernel_dtype=kernel_dtype, + out_channels=out_channels, + weight_format=kernel_layout, + enable_bias=True, + relu_type="NONE", + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert not any(attrs), "No function should have an external attribute." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_cmsisnn/test_extract_constants.py b/tests/python/contrib/test_cmsisnn/test_extract_constants.py new file mode 100644 index 000000000000..7305240110e8 --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_extract_constants.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: extract_constants pass""" +import itertools +import math +import numpy as np +import pytest +import tvm +from tvm import relay + +from utils import ( + make_module, + count_num_calls, + get_range_for_dtype_str, + get_same_padding, + get_conv2d_qnn_params, + make_qnn_relu, +) + +tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__) + + +class CheckFunctionsForConstants(tvm.relay.ExprVisitor): + def __init__(self): + super().__init__() + self.num_constants_ = 0 + + def visit_call(self, call): + super().visit_call(call) + for arg in call.args: + if isinstance(arg, relay.Constant) and arg.data.numpy().ndim > 0: + self.num_constants_ += 1 + + def visit_function(self, func): + super().visit_function(func) + assert self.num_constants_ == 0, "Functions should not have constant arguments in Calls" + + +def set_external_func_attr(func, compiler, ext_symbol): + func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.with_attr("Compiler", compiler) + func = func.with_attr("global_symbol", ext_symbol) + return func + + +@tvm.testing.requires_cmsisnn +def test_external_function(): + y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + x0 = relay.var("x0", shape=(8, 8)) + y0_const = relay.const(y0_data, "float32") + z0 = x0 + y0_const + ef = relay.Function([x0], z0, relay.TensorType((8, 8), "float32")) + ev = relay.GlobalVar("external_function") + ef = set_external_func_attr(ef, "external_compiler", ev.name_hint) + + x = relay.var("x", shape=(8, 8)) + c = relay.Call(ev, [x]) + mf = relay.Function([x], c, relay.TensorType((8, 8), "float32")) + mv = relay.GlobalVar("main") + + mod = tvm.IRModule() + mod[ev] = ef + mod[mv] = mf + + mod = ExtractConstantsFromPartitionedFunction()(mod) + CheckFunctionsForConstants().visit_function(mod[ev]) + relay.transform.InferType()(mod) + + +@tvm.testing.requires_cmsisnn +def test_nested_function(): + y1_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + x1 = relay.var("x1", shape=(8, 8)) + y1_const = relay.const(y1_data, "float32") + z1 = x1 + y1_const + w1 = z1 * relay.const(5.0, "float32") + lf = relay.Function([x1], w1, relay.TensorType((8, 8), "float32")) + + x0 = relay.var("x0", shape=(8, 8)) + c0 = relay.Call(lf, [x0]) + ef = relay.Function([x0], c0, relay.TensorType((8, 8), "float32")) + + x = relay.var("x", shape=(8, 8)) + ev = relay.GlobalVar("external_function") + ef = set_external_func_attr(ef, "external_compiler", ev.name_hint) + c = relay.Call(ev, [x]) + mf = relay.Function([x], c, relay.TensorType((8, 8), "float32")) + mv = relay.GlobalVar("main") + + mod = tvm.IRModule() + mod[ev] = ef + mod[mv] = mf + + mod = ExtractConstantsFromPartitionedFunction()(mod) + CheckFunctionsForConstants().visit_function(mod[ev]) + relay.transform.InferType()(mod) + + +@tvm.testing.requires_cmsisnn +def test_multiple_functions(): + y20_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + x20 = relay.var("x20", shape=(8, 8)) + y20_const = relay.const(y20_data, "float32") + z20 = x20 + y20_const + f20 = relay.Function([x20], z20, relay.TensorType((8, 8), "float32")) + + y21_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + x21 = relay.var("x21", shape=(8, 8)) + y21_const = relay.const(y21_data, "float32") + z21 = x21 + y21_const + f21 = relay.Function([x21], z21, relay.TensorType((8, 8), "float32")) + + x10 = relay.var("x10", shape=(8, 8)) + c10 = relay.Call(f20, [x10]) + c11 = relay.Call(f21, [c10]) + ef = relay.Function([x10], c11, relay.TensorType((8, 8), "float32")) + + x0 = relay.var("x0", shape=(8, 8)) + ev = relay.GlobalVar("external_function") + ef = set_external_func_attr(ef, "external_compiler", ev.name_hint) + c = relay.Call(ev, [x0]) + mf = relay.Function([x0], c, relay.TensorType((8, 8), "float32")) + mv = relay.GlobalVar("main") + + mod = tvm.IRModule() + mod[ev] = ef + mod[mv] = mf + + mod = ExtractConstantsFromPartitionedFunction()(mod) + CheckFunctionsForConstants().visit_function(mod[ev]) + relay.transform.InferType()(mod) + + +@tvm.testing.requires_cmsisnn +def test_main_function(): + x0 = relay.var("x0", shape=(8, 8)) + y0 = relay.var("y0", shape=(8, 8)) + z0 = x0 + y0 + ef = relay.Function([x0, y0], z0, relay.TensorType((8, 8), "float32")) + ev = relay.GlobalVar("external_function") + ef = set_external_func_attr(ef, "external_compiler", ev.name_hint) + + x = relay.var("x", shape=(8, 8)) + y_data = np.random.uniform(0, 1, (8, 8)).astype("float32") + y_const = relay.const(y_data, "float32") + z = x + y_const + c = relay.Call(ev, [x, z]) + mf = relay.Function([x], c, relay.TensorType((8, 8), "float32")) + mv = relay.GlobalVar("main") + + mod = tvm.IRModule() + mod[ev] = ef + mod[mv] = mf + + mod = ExtractConstantsFromPartitionedFunction()(mod) + check_for_constants = CheckFunctionsForConstants() + check_for_constants.visit_call(mod[mv].body) + assert ( + check_for_constants.num_constants_ == 1 + ), "main() should have same number of arguments as before" + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_cmsisnn/test_generate_constants.py b/tests/python/contrib/test_cmsisnn/test_generate_constants.py new file mode 100644 index 000000000000..c5e97253d94b --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_generate_constants.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: generate_constants pass""" +import itertools +import math +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + +from utils import ( + make_module, + count_num_calls, + get_range_for_dtype_str, + get_same_padding, + get_conv2d_qnn_params, + make_qnn_relu, +) + +tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__) + + +def quantize_scale(scale): + multiplier, shift = math.frexp(scale) + multiplier_q31 = round(multiplier * (1 << 31)) + return multiplier_q31, shift + + +class CheckGeneratedConstants(tvm.relay.ExprVisitor): + def __init__(self, enable_bias, multiplier, shift): + super().__init__() + self.num_constant_args_ = 0 + self.enable_bias_ = enable_bias + self.multiplier_ = multiplier + self.shift_ = shift + + def visit_call(self, call): + super().visit_call(call) + if isinstance(call.op, tvm.ir.expr.GlobalVar): + # extern_fn_call(input, weight, multiplier, weight_scale, bias_optional, input_scale, shift) + multiplier = call.args[2] + shift = call.args[6] if self.enable_bias_ else call.args[5] + assert isinstance( + multiplier, relay.expr.Constant + ), "Expected quantized multiplier at argument#3" + assert isinstance( + shift, relay.expr.Constant + ), "Expected a constant while looking for quantized shift" + multiplier = multiplier.data.numpy() + shift = shift.data.numpy() + tvm.testing.assert_allclose(multiplier, self.multiplier_, atol=100, rtol=1e-10) + tvm.testing.assert_allclose(shift, self.shift_, atol=1, rtol=1e-5) + + +def make_model( + shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + kernel_dtype, + out_channels, + weight_format, + enable_bias, + relu_type, +): + """Return a model and any parameters it may have""" + h_index = weight_format.index("H") + w_index = weight_format.index("W") + kernel_h = kernel_shape[h_index] + kernel_w = kernel_shape[w_index] + a = relay.var("input", shape=shape, dtype=dtype) + p = (0, 0, 0, 0) + if padding == "SAME": + p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides) + a = relay.nn.pad( + a, + pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)], + pad_value=input_zero_point, + pad_mode="constant", + ) + shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3]) + + weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels) + rng = np.random.default_rng(12321) + w = tvm.nd.array( + rng.integers( + np.iinfo(kernel_dtype).min, + high=np.iinfo(kernel_dtype).max, + size=weight_shape, + dtype=kernel_dtype, + ) + ) + weight_const = relay.const(w, kernel_dtype) + conv = relay.qnn.op.conv2d( + a, + weight_const, + input_zero_point=relay.const(input_zero_point, "int32"), + kernel_zero_point=relay.const(kernel_zero_point, "int32"), + input_scale=relay.const(input_scale, "float32"), + kernel_scale=relay.const(kernel_scale, "float32"), + kernel_size=(kernel_h, kernel_w), + data_layout="NHWC", + kernel_layout=weight_format, + dilation=dilation, + strides=strides, + groups=groups, + channels=out_channels, + padding=p, + out_dtype="int32", + ) + b = tvm.nd.array(rng.integers(0, high=10, size=(out_channels,), dtype="int32")) + bias_const = relay.const(b, "int32") + last_op = relay.nn.bias_add(conv, bias_const, axis=3) if enable_bias else conv + requant_input_sc = [sc * input_scale for sc in kernel_scale] + last_op = relay.qnn.op.requantize( + last_op, + relay.const(requant_input_sc, "float32"), + relay.const(0, "int32"), + relay.const(output_scale, "float32"), + relay.const(output_zero_point, "int32"), + out_dtype=dtype, + ) + last_op = make_qnn_relu(last_op, relu_type, output_scale, output_zero_point, dtype) + params = {"w": w, "b": b} + return last_op, params + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale, out_channels", + [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], +) +def test_op_int8( + enable_bias, + input_zero_point, + input_scale, + kernel_scale, + out_channels, +): + ifm_shape = (1, 28, 28, 3) + padding = "VALID" + strides = (1, 1) + dilation = (1, 1) + kernel_size = (3, 3) + kernel_zero_point = 0 + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + dtype = "int8" + relu_type = "RELU" + in_min, in_max = get_range_for_dtype_str(dtype) + + weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + + output_scale, output_zero_point = get_conv2d_qnn_params( + weight_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + False, + ) + + model, params = make_model( + ifm_shape, + weight_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + mod = make_module(model) + + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(mod, params) + multiplier_array = [] + shift_array = [] + for i in range(out_channels): + multiplier, shift = quantize_scale(input_scale * kernel_scale[i] / output_scale) + multiplier_array.append(multiplier) + shift_array.append(shift) + CheckGeneratedConstants(enable_bias, multiplier_array, shift_array).visit_function( + cmsisnn_mod["main"] + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py index e78b06cbab84..a099e8102d8b 100644 --- a/tests/python/contrib/test_cmsisnn/test_networks.py +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -87,8 +87,8 @@ def test_cnn_small(): tflite_model_buf = f.read() input_shape = (1, 490) - in_min, in_max = get_range_for_dtype_str("int8") - input_data = np.random.randint(in_min, high=in_max, size=input_shape).astype(np.float32) + rng = np.random.default_rng(12345) + input_data = rng.random(input_shape, dtype=np.float32) orig_mod, params = convert_to_relay(tflite_model_buf, input_data, "input") cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) @@ -101,7 +101,13 @@ def test_cnn_small(): params = {} output_list = generate_ref_data(orig_mod["main"], inputs, params) compile_and_run( - AOTTestModel(module=cmsisnn_mod, inputs=inputs, outputs=output_list, params=params), + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), test_runner, interface_api, use_unpacked_api, diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py index 40e12fc962b2..5d1c2fdcc8c1 100644 --- a/tests/python/contrib/test_cmsisnn/test_softmax.py +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""CMSIS-NN integration tests: softmax""" +"""CMSIS-NN integration tests: Softmax""" import sys import itertools @@ -64,7 +64,7 @@ def make_model( @skip_if_no_reference_system @pytest.mark.parametrize(["zero_point", "scale"], [[33, 0.256], [-64, 0.0128]]) @tvm.testing.requires_cmsisnn -def test_softmax_int8(zero_point, scale): +def test_op_int8(zero_point, scale): interface_api = "c" use_unpacked_api = True test_runner = AOT_CORSTONE300_RUNNER @@ -135,7 +135,7 @@ def parameterize_for_invalid_model(test): @parameterize_for_invalid_model @tvm.testing.requires_cmsisnn -def test_invalid_softmax(in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale): +def test_invalid_parameters(in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale): model = make_model( [1, 16, 16, 3], in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale ) diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py index 3fd12efd3367..145dbf4b499c 100644 --- a/tests/python/contrib/test_cmsisnn/utils.py +++ b/tests/python/contrib/test_cmsisnn/utils.py @@ -19,8 +19,10 @@ import platform +import math import numpy as np import pytest +from typing import List, Dict, Optional, Any, Union, Tuple import tvm from tvm import relay @@ -33,7 +35,7 @@ def skip_if_no_reference_system(func): def count_num_calls(mod): - """Count number of CallNode in the IRModule""" + """Counts number of CallNode(s) in the IRModule""" class CallCounter(relay.ExprVisitor): def __init__(self): @@ -54,7 +56,7 @@ def visit_call(self, call): def get_range_for_dtype_str(dtype): """ - Produce the min,max for a give data type. + Produces the min,max for a give data type. Parameters ---------- @@ -77,7 +79,124 @@ def get_range_for_dtype_str(dtype): def make_module(func): - """Create IRModule from Function""" + """Creates IRModule from Function""" func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) - return relay.transform.InferType()(mod) + mod = relay.transform.InferType()(mod) + return mod + + +def get_same_padding(data, kernel, dilation, stride, cmsisnn_padding=True): + """Provides CMSIS-NN padding when output dim == input dim""" + dilated_kernel_h = dilation[0] * (kernel[0] - 1) + 1 + dilated_kernel_w = dilation[1] * (kernel[1] - 1) + 1 + out = int(math.ceil(float(data[0]) / float(stride[0]))) + pad = max(0, (out - 1) * stride[0] + dilated_kernel_h - data[0]) + pad_top, pad_bottom = (pad, 0) if cmsisnn_padding else (0, pad) + + out = int(math.ceil(float(data[1]) / float(stride[1]))) + pad = max(0, (out - 1) * stride[1] + dilated_kernel_w - data[1]) + pad_left, pad_right = (pad, 0) if cmsisnn_padding else (0, pad) + return [pad_top, pad_left, pad_bottom, pad_right] + + +def get_conv2d_qnn_params( + weight_shape: List[int], + input_scale: float, + input_zp: int, + weights_scale: Union[float, List[float]], + weights_zp: int, + input_dtype: str = "int8", + weights_dtype: str = "int8", + output_dtype: str = "int8", + is_depthwise: bool = False, +) -> Tuple[float, int]: + """ + Calculate the output quantization parameters for convolution based on the input and + weights quantization paramters and the data types. + + Parameters + ---------- + weight_shape : List[int] + shape of the weights + input_scale : float + scale of the input tensor + input_zp : int + zero point of the input tensor + weights_scale : Union[float, List[float]] + scale(s) of the weights tensor + weights_zp : int + zero point of the weights tensor + is_depthwise : bool + whether it is a depthwise convolution + input_dtype : str + data type of the input tensor + weights_dtype : str + data type of the weights tensor + output_dtype : str + data type of the output tensor + + Returns + ------- + output_scale : float + scale of the output tensor + output_zp : int + zero point of the output tensor + """ + input_dtype_min, input_dtype_max = get_range_for_dtype_str(input_dtype) + input_max = input_scale * (input_dtype_max - input_zp) + input_min = input_scale * (input_dtype_min - input_zp) + + weights_dtype_min, weights_dtype_max = get_range_for_dtype_str(weights_dtype) + weights_sc_max = np.max(weights_scale) + weights_max = weights_sc_max * (weights_dtype_max - weights_zp) + + weights_sc_min = np.min(weights_scale) + weights_min = weights_sc_min * (weights_dtype_min - weights_zp) + + weights_h = weight_shape[1] + weights_w = weight_shape[2] + channels = weight_shape[3] + num_elements = weights_h * weights_w * channels + # Adjust the result if it is a depthwise convolution + if is_depthwise: + num_elements = num_elements / channels + + # The smallest and largest possible values in the unquantized output tensor + output_limits = [ + weights_max * input_max * num_elements, + weights_min * input_max * num_elements, + weights_min * input_min * num_elements, + weights_max * input_min * num_elements, + ] + + output_max = max(output_limits) + output_min = min(output_limits) + output_dtype_min, output_dtype_max = get_range_for_dtype_str(output_dtype) + + output_scale = (output_max - output_min) / (output_dtype_max - output_dtype_min) + output_zp = int(output_dtype_min - (output_min / output_scale)) + + return output_scale, output_zp + + +def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype): + """Mimics convert_qnn_fused_activation_function from TFLite frontend""" + quantize = lambda x: float(int(round(x / scale)) + zero_point) + + # Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not + # beyond the dtype range. + qmin, qmax = get_range_for_dtype_str(dtype) + + # The input expr is a quantized tensor with its scale and zero point. We calculate the + # suitable clip off points based on these scale and zero point. + if fused_activation_fn == "NONE": + return expr + if fused_activation_fn == "RELU6": + return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0)), a_max=min(qmax, quantize(6.0))) + if fused_activation_fn == "RELU_N1_TO_1": + return tvm.relay.op.clip( + expr, a_min=max(qmin, quantize(-1.0)), a_max=min(qmax, quantize(1.0)) + ) + if fused_activation_fn == "RELU": + return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax)