diff --git a/CMakeLists.txt b/CMakeLists.txt index a094beb34f6e..5c01b2caad2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,6 +48,7 @@ tvm_option(USE_TF_COMPILE_FLAGS "Build with TensorFlow's compile flags." OFF) tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF) +tvm_option(USE_CMSISNN "Build with Arm CMSIS-NN" OFF) tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON) tvm_option(USE_LIBBACKTRACE "Build libbacktrace to supply linenumbers on stack traces" AUTO) tvm_option(BUILD_STATIC_RUNTIME "Build static version of libtvm_runtime" OFF) @@ -418,6 +419,7 @@ include(cmake/modules/ROCM.cmake) include(cmake/modules/LLVM.cmake) include(cmake/modules/Micro.cmake) include(cmake/modules/contrib/EthosN.cmake) +include(cmake/modules/contrib/CMSISNN.cmake) include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/CODEGENC.cmake) include(cmake/modules/contrib/DNNL.cmake) diff --git a/cmake/modules/contrib/CMSISNN.cmake b/cmake/modules/contrib/CMSISNN.cmake new file mode 100644 index 000000000000..4bd0fd865dc0 --- /dev/null +++ b/cmake/modules/contrib/CMSISNN.cmake @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
# Optional CMSIS-NN support: when USE_CMSISNN is enabled, the CMSIS-NN Relay
# backend sources are appended to COMPILER_SRCS.
if(USE_CMSISNN)
  message(STATUS "Build with CMSIS-NN support")
  # Pick up every .cc file of the CMSIS-NN contrib backend.
  file(GLOB RELAY_CONTRIB_CMSISNN_SRCS src/relay/backend/contrib/cmsisnn/*.cc)
  list(APPEND COMPILER_SRCS ${RELAY_CONTRIB_CMSISNN_SRCS})
endif(USE_CMSISNN)
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""external backend codegen modules for relay.""" +from . import cmsisnn diff --git a/python/tvm/relay/backend/contrib/cmsisnn/__init__.py b/python/tvm/relay/backend/contrib/cmsisnn/__init__.py new file mode 100644 index 000000000000..cc6873f9fda6 --- /dev/null +++ b/python/tvm/relay/backend/contrib/cmsisnn/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CMSIS-NN codegen modules for relay.""" +from . 
class GenerateTIR(ExprVisitor):
    """Generates TIR module containing TIR primfuncs corresponding to the Relay operators.

    The visitor walks a Relay function and, when it finds the quantized softmax
    sequence, builds a PrimFunc wrapping a ``call_extern`` to ``arm_softmax_s8``.

    Note: Relay operator to primfunc mapping may not be 1:1.
    """

    def __init__(self, name):
        super().__init__()
        # Used both as the PrimFunc's global_symbol and as its key in the IRModule.
        self.name = name
        # Remains None until a supported pattern has been lowered.
        self.tir_mod = None
        # Initial scale; replaced by the dequantize scale once the pattern matches.
        self.scale = 1.0 / 256

    def call_contains_op(self, call, op_name):
        """Return True when `call` is a Relay call to the operator named `op_name`.

        Parameters
        ----------
        call : tvm.relay.Expr
            Expression to inspect. It does not have to be a call.
        op_name : str
            Operator name to compare against, e.g. "nn.softmax".
        """
        # Arguments of a call can be variables or constants, which have no `op`
        # attribute; treat them as "no match" rather than raising.
        if not isinstance(call, relay.Call):
            return False
        return isinstance(call.op, tvm.ir.op.Op) and call.op.name == op_name

    def is_quantized_softmax(self, call):
        """Checks for the following relay sequence
            a = qnn.dequantize(in, scale, zero_point)
            b = nn.softmax(a)
            c = qnn.quantize(b, scale, zero_point)

        On a match, the scale argument of the dequantize call is stored in
        ``self.scale`` for later use by ``emit_softmax_tir``.
        """
        if not self.call_contains_op(call, "qnn.quantize"):
            return False
        softmax_call = call.args[0]
        if not self.call_contains_op(softmax_call, "nn.softmax"):
            return False
        dequantize_call = softmax_call.args[0]
        if not self.call_contains_op(dequantize_call, "qnn.dequantize"):
            return False
        self.scale = dequantize_call.args[1].data.numpy().item(0)
        return True

    def emit_softmax_tir(self, call):
        """Generates TIR extern_call for softmax

        The result is stored in ``self.tir_mod`` as an IRModule holding a single
        PrimFunc named ``self.name``.
        """
        # The last axis is used as the softmax row; all leading axes are folded
        # into the number of rows.
        dims = list(call.checked_type.shape)
        dtype = call.checked_type.dtype
        ir_builder = tvm.tir.ir_builder.create()
        in_buf = tvm.tir.decl_buffer(shape=call.checked_type.shape, dtype=dtype)
        out_buf = tvm.tir.decl_buffer(shape=call.checked_type.shape, dtype=dtype)

        num_rows = 1
        for extent in dims[:-1]:
            num_rows *= extent
        row_size = dims[-1]
        ir_builder.emit(
            tvm.tir.call_extern(
                dtype,
                "arm_softmax_s8",
                in_buf.data,
                num_rows,
                row_size,
                self.scale,
                out_buf.data,
            )
        )
        prim_func = tvm.tir.PrimFunc([in_buf, out_buf], ir_builder.get())
        prim_func = prim_func.with_attr("global_symbol", self.name)
        prim_func = prim_func.with_attr("tir.noalias", True)
        self.tir_mod = tvm.IRModule({self.name: prim_func})

    def visit_call(self, call):
        """Iterates over the relay operators within relay external function"""
        # Visit the arguments first, then test the current call for the pattern.
        super().visit_call(call)
        if self.is_quantized_softmax(call):
            self.emit_softmax_tir(call)

    def generate_tir(self, func):
        """Visit `func` and return the generated TIR module (None if nothing matched)."""
        self.visit(func)
        return self.tir_mod
def relay_to_tir(name, func):
    """Lower a Relay function to TIR for the CMSIS-NN target.

    The Relay function is expected to hold only operations supported by the
    CMSIS-NN target; the CMSIS-NN graph partitioner is responsible for that.

    Parameters
    ----------
    name: str
        Name of the external relay function
    func : tvm.relay.Function
        The Relay function to lower.

    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.

    """
    tir_generator = GenerateTIR(name)
    return tir_generator.generate_tir(func)


@tvm.register_func("relay.ext.cmsisnn")
def cmsisnn_compiler(relay_func):
    """Compile a Relay external function through TIR for CMSIS-NN.

    The function is type-inferred, lowered to an equivalent TIR module and
    handed to the "runtime.module.cmsisnn.create" global function, which
    turns it into 'c' code embedding CMSIS-NN APIs for the corresponding
    operators.
    """
    wrapper_mod = tvm.IRModule()
    wrapper_mod["main"] = relay_func
    typed_mod = relay.transform.InferType()(wrapper_mod)
    symbol = relay_func.attrs["global_symbol"]
    lowered = relay_to_tir(symbol, typed_mod["main"])
    create_runtime_module = tvm._ffi.get_global_func("runtime.module.cmsisnn.create")
    return create_runtime_module(lowered)
b/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc new file mode 100644 index 000000000000..d2e498a52ce4 --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include + +#include "../../../../runtime/file_utils.h" +#include "../../../../target/source/codegen_c.h" +#include "../../../qnn/utils.h" + +namespace tvm { +namespace runtime { + +using namespace tir; + +class CodeGenCMSISNN : public tvm::codegen::CodeGenC { + public: + void Init(bool output_ssa) { + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + CodeGenC::Init(output_ssa); + } + + /*! 
+ * \brief Emit code that offloads a subgraph to the Cortex-M + * + * \return string of code that offloads a subgraph to the Cortex-M + */ + void AddFunction(const PrimFunc& prim_func) { + PrintExternCPrefix(stream); + CodeGenC::AddFunction(prim_func); + PrintExternCPostfix(stream); + } + + private: + void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + if (!op->op.same_as(builtin::call_extern())) { + return; + } + std::string cmsis_func_name = op->args[0].as()->value; + if (cmsis_func_name == "arm_softmax_s8") { + EmitSoftmax(op); + } + return; + } + + /*! * \brief Creates a cplusplus guard prefix for extern "C" printing */ + void PrintExternCPrefix(std::ostringstream& ss) { + PrintIndent(); + ss << "#ifdef __cplusplus\n"; + ss << "extern \"C\" {\n"; + ss << "#endif\n"; + } + + /*! * \brief Creates a cplusplus guard postfix for extern "C" printing */ + void PrintExternCPostfix(std::ostringstream& ss) { + PrintIndent(); + ss << "#ifdef __cplusplus\n"; + ss << "}\n"; + ss << "#endif\n"; + } + + /*! 
* \brief Emits CMSIS-NN code block for softmax */ + void EmitSoftmax(const CallNode* op) { + // @tir.call_extern("arm_softmax_s8", buffer_0, num_rows, row_size, scale, buffer_1, dtype=int8) + std::string cmsis_func_name = op->args[0].as()->value; + int32_t num_rows = op->args[2].as()->value; + int32_t row_size = op->args[3].as()->value; + float quant_scale = op->args[4].as()->value; + + // calculate multiplier and shift for CMSIS-NN softmax API + // Note: tfl micro assumptions + // TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); + // TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); + // TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); + double beta = 1.0; + int32_t input_bits = 5; + double beta_multiplier = (beta * quant_scale * (1 << (31 - input_bits))); + beta_multiplier = std::min(beta_multiplier, (1ll << 31) - 1.0); + auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier); + int32_t mult = std::get<0>(mult_shift_pair); + int32_t shift = std::get<1>(mult_shift_pair); + int32_t diff_min = (1 << 5) - 1; + diff_min <<= (31 - 5); + diff_min >>= shift; + diff_min *= -1; + + PrintIndent(); + stream << "int32_t num_rows = " << num_rows << ";\n"; + PrintIndent(); + stream << "int32_t row_size = " << row_size << ";\n"; + PrintIndent(); + stream << "int32_t mult = " << mult << ";\n"; + PrintIndent(); + stream << "int32_t shift = " << shift << ";\n"; + PrintIndent(); + stream << "int32_t diff_min = " << diff_min << ";\n"; + PrintIndent(); + stream << cmsis_func_name << "(buffer,"; + PrintIndent(); + stream << " num_rows, row_size, mult, shift, diff_min, buffer1);\n"; + PrintIndent(); + stream << "return;\n"; + } +}; + +class CMSISNNModuleNode : public runtime::ModuleNode { + public: + CMSISNNModuleNode(const std::string& code, const std::string& fmt, + const Array& func_names) + : code_(code), fmt_(fmt), func_names_(func_names) {} + + std::string GetSource(const std::string& format) final { return code_; } 
+ + const char* type_key() const { return "c"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); + } else { + return PackedFunc(nullptr); + } + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + std::string meta_file = GetMetaFilePath(file_name); + if (fmt == "c" || fmt == "cu") { + ICHECK_NE(code_.length(), 0); + SaveBinaryToFile(file_name, code_); + } else { + ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + } + } + + protected: + std::string code_; + std::string fmt_; + Array func_names_; +}; + +class CMSISNNModule : public Module { + public: + CMSISNNModule() {} + explicit CMSISNNModule(ObjectPtr n) : Module(n) {} + inline CMSISNNModuleNode* operator->(); + inline const CMSISNNModuleNode* operator->() const; +}; + +inline CMSISNNModuleNode* CMSISNNModule::operator->() { + return static_cast(get_mutable()); +} + +static Module CMSISNNModuleNodeCreate(IRModule mod) { + bool output_ssa = false; + CodeGenCMSISNN cg; + Array function_names; + cg.Init(output_ssa); + ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc."; + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; + function_names.push_back(global_symbol.value()); + cg.AddFunction(f); + } + std::string code = cg.Finish(); + auto n = make_object(code, "c", function_names); + return Module(n); +} + 
+TVM_REGISTER_GLOBAL("runtime.module.cmsisnn.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = CMSISNNModuleNodeCreate(args[0]); +}); + +} // namespace runtime +} // namespace tvm diff --git a/tests/python/contrib/test_cmsisnn/test_networks.py b/tests/python/contrib/test_cmsisnn/test_networks.py new file mode 100644 index 000000000000..1f6e0e711f0c --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_networks.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN: testing with networks""" + +import platform +import sys +import os +import pathlib +import tvm +from tvm import relay +from tvm.contrib.download import download_testdata +from tvm.relay.op.contrib import cmsisnn +import numpy as np +import pytest +import itertools + +from tests.python.relay.aot.aot_test_utils import ( + AOTTestModel, + AOT_CORSTONE300_RUNNER, + generate_ref_data, + compile_and_run, +) + + +def get_range_for_dtype_str(dtype): + """ + Produce the min,max for a give data type. 
def convert_to_relay(
    tflite_model_buf,
    input_data,
    input_node,
):
    """Import a serialized TFLite model into Relay.

    Parameters
    ----------
    tflite_model_buf : bytes
        Content of the .tflite file.
    input_data : object or list
        One input value or a list of them; each must expose `shape` and `dtype`.
    input_node : str or list of str
        Input name(s), matched to `input_data` by position.

    Returns
    -------
    mod, params
        The pair returned by `relay.frontend.from_tflite`.
    """

    def convert_to_list(x):
        return x if isinstance(x, list) else [x]

    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
    try:
        import tflite.Model

        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    except AttributeError:
        import tflite

        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
    except ImportError:
        raise ImportError("The tflite package must be installed")

    tensors = convert_to_list(input_data)
    names = convert_to_list(input_node)

    # Describe every named input by the shape and dtype of its sample tensor.
    shape_dict, dtype_dict = {}, {}
    for position, node_name in enumerate(names):
        tensor = tensors[position]
        shape_dict[node_name] = tensor.shape
        dtype_dict[node_name] = tensor.dtype.name

    module, parameters = relay.frontend.from_tflite(
        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
    )
    return module, parameters
def get_range_for_dtype_str(dtype):
    """
    Return the minimum and maximum value of a given data type.

    Parameters
    ----------
    dtype : str
        a type string (e.g., int8)

    Returns
    -------
    lowest : int or float
        the minimum of the range
    highest : int or float
        the maximum of the range
    """
    try:
        limits = np.iinfo(dtype)
    except ValueError:
        # np.iinfo rejects non-integer types; use the floating point limits instead.
        limits = np.finfo(dtype)
    lowest, highest = limits.min, limits.max
    return lowest, highest
def parameterize_for_invalid_model(test):
    """Parametrize `test` with every configuration except the excluded one.

    All combinations of the candidate dtypes, zero points and scales are
    generated; those with int8 input, int8 output, an output zero point of -128
    and an output scale of 1/256 are left out.
    """
    candidates = {
        "in_dtype": ["uint8", "int8"],
        "out_dtype": ["uint8", "int8"],
        "zero_point": [-128, 64],
        "scale": [1.0 / 256, 0.2],
        "out_zero_point": [-128, 33],
        "out_scale": [1.0 / 256, 0.2],
    }

    def is_excluded(combo):
        in_dtype, out_dtype, _, _, out_zero_point, out_scale = combo
        return (
            in_dtype == "int8"
            and out_dtype == "int8"
            and out_zero_point == -128
            and out_scale == 1.0 / 256
        )

    kept = [combo for combo in itertools.product(*candidates.values()) if not is_excluded(combo)]
    decorate = pytest.mark.parametrize(
        ["in_dtype", "out_dtype", "zero_point", "scale", "out_zero_point", "out_scale"],
        kept,
    )
    return decorate(test)
@pytest.fixture(scope="session")
def tflite_cnn_s_quantized(tmpdir_factory):
    """Download the quantized cnn_small keyword spotting model.

    Returns whatever `download_testdata` returns for the file
    (presumably the local path of the downloaded model - confirm).
    """
    # No trailing slash here: the file name is joined with "/" below, and a
    # trailing slash would put "//" into the URL.
    base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
    file_to_download = "cnn_s_quantized.tflite"
    model_file = download_testdata(
        "{}/{}".format(base_url, file_to_download), file_to_download, module=["tvmc"]
    )
    return model_file
def test_compile_tflite_module_with_external_codegen_cmsisnn(
    tmpdir_factory, tflite_cnn_s_quantized
):
    """Compile the quantized CNN model for the cmsis-nn target into an MLF archive.

    Checks that the archive is created and that it holds three C sources under
    codegen/host/src whose names end in digits.
    """
    pytest.importorskip("tflite")

    output_dir = tmpdir_factory.mktemp("mlf")
    tvmc_model = tvmc.load(tflite_cnn_s_quantized)

    output_file_name = f"{output_dir}/file.tar"

    # Plain string: the target has no placeholders, so no f-prefix is needed.
    tvmc_package = tvmc.compiler.compile_model(
        tvmc_model,
        target="cmsis-nn, c -runtime=c --system-lib --link-params -mcpu=cortex-m55 --executor=aot",
        output_format="mlf",
        package_path=output_file_name,
        pass_context_configs=["tir.disable_vectorize=true"],
    )

    # check whether an MLF package was created
    assert os.path.exists(output_file_name)

    # check whether the expected number of C sources are in the tarfile
    with tarfile.open(output_file_name) as mlf_package:
        c_source_files = [
            name
            for name in mlf_package.getnames()
            if re.match(r"\./codegen/host/src/\D+\d+\.c", name)
        ]
        assert len(c_source_files) == 3