diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index 7af47c3a81a1..e7bbfb630a72 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -57,6 +57,7 @@ def partition_for_cmsisnn(mod, params=None, **opts):
             transform.AnnotateTarget("cmsis-nn"),
             transform.PartitionGraph(),
             GenerateCMSISNNConstants(),
+            ScalarToTensorConstants(),
             ExtractConstantsFromPartitionedFunction(),
             transform.InferType(),
         ]
@@ -223,11 +224,23 @@ def binary_op_pattern(op):
             is_constant(),
         )
 
-    def check_qnn_binary_op(extract):
+    def check_qnn_binary_op(pattern):
         """Check if multiply is supported by CMSIS-NN."""
+        arg0 = pattern.args[0]
+        arg1 = pattern.args[1]
+        both_args_scalar = False
+        if (
+            isinstance(arg0, tvm.relay.expr.Constant)
+            and len(arg0.checked_type.shape) == 0
+            and isinstance(arg1, tvm.relay.expr.Constant)
+            and len(arg1.checked_type.shape) == 0
+        ):
+            both_args_scalar = True
+
         return (
-            extract.args[0].checked_type.dtype == "int8"
-            and extract.args[1].checked_type.dtype == "int8"
+            arg0.checked_type.dtype == "int8"
+            and arg1.checked_type.dtype == "int8"
+            and not both_args_scalar
         )
 
     return [
diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
index 9b724034ccf2..1cbe36e30f76 100644
--- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc
+++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -62,17 +62,88 @@ class ExtractConstantsMutator : public MixedModeMutator {
       return func;
     }
 
-    function_to_constants_.Set(func, Array<Constant>{});
+    function_to_arguments_.Set(func, Array<Expr>{});
     functions_.push_back(func);
     auto new_body = VisitExpr(func->body);
     functions_.pop_back();
-    if (function_to_constants_[func].size()) {
+    if (function_to_arguments_[func].size()) {
       func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
                         FreeTypeVars(new_body, mod_), func->attrs);
     }
     return std::move(func);
   }
 
+  // Creates new arguments from the current call's arguments.
+  // Updates constants into the caller's arguments; here, the caller is the function that
+  // contains the call to func.
+  Array<Expr> CreateNewCallArgsFromExtractedConstants(Call call, Function func) {
+    ICHECK(function_to_arguments_.find(func) != function_to_arguments_.end());
+    Array<Expr> function_signature(function_to_arguments_[func]);
+
+    // Is func a global function?
+    // main() is not registered for extracting constants
+    bool is_global_function = functions_.empty();
+
+    bool new_constants_added = false;
+    // This tracks arguments traversed inside function_signature
+    uint32_t function_signature_id = 0;
+    // This contains arguments, including constants, for the caller of this function inside which
+    // post_call resides.
+    Array<Expr> new_caller_args;
+    // New arguments to post_call that include new variables representing constants extracted
+    // from the function
+    Array<Expr> new_call_args;
+    for (auto& arg : call->args) {
+      if (auto* constant = arg.as<ConstantNode>()) {
+        new_caller_args.push_back(arg);
+        new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+        ++function_signature_id;
+        new_constants_added = true;
+        continue;
+      }
+
+      // Push all constants from the function_signature until a variable corresponding to the
+      // current argument is hit
+      while (function_signature_id < function_signature.size()) {
+        auto* constant = function_signature[function_signature_id].as<ConstantNode>();
+        if (constant == nullptr) {
+          break;
+        }
+        new_caller_args.push_back(function_signature[function_signature_id++]);
+        new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+        new_constants_added = true;
+      }
+
+      new_call_args.push_back(arg);
+      if (is_global_function || arg.as<VarNode>()) {
+        new_caller_args.push_back(arg);
+      }
+      ++function_signature_id;
+    }
+
+    // Push the remaining constants as new arguments
+    for (uint32_t i = function_signature_id; i < function_signature.size(); ++i) {
+      auto* constant = function_signature[i].as<ConstantNode>();
+      ICHECK(constant)
+          << "Rest of the collected arguments should be constants in the partitioned function.";
+      new_caller_args.push_back(GetRef<Constant>(constant));
+      new_call_args.push_back(Var(gen_var_name(), constant->tensor_type()));
+      new_constants_added = true;
+    }
+
+    // Update the arguments of the caller of the local function
+    if (new_constants_added && !is_global_function) {
+      const Function& last_func = functions_.back();
+      Array<Expr> function_constants(function_to_arguments_[last_func]);
+      function_to_arguments_.Set(last_func,
+                                 tvm::runtime::Concat(function_constants, new_caller_args));
+    } else {
+      new_call_args = new_caller_args;
+    }
+
+    return new_call_args;
+  }
+
   Expr Rewrite_(const CallNode* call, const Expr& post) final {
     Expr final_call = post;
     auto* post_call = post.as<CallNode>();
@@ -81,23 +152,28 @@ class ExtractConstantsMutator : public MixedModeMutator {
     // Perform this for non-main Call Nodes only
     if (!functions_.empty() && call->op.as<OpNode>()) {
       Array<Expr> new_args;
+      const Function& last_func = functions_.back();
+      Array<Expr> function_signature(function_to_arguments_[last_func]);
       for (auto& arg : post_call->args) {
+        // Push all arguments, including constants, to maintain the correct order of
+        // variables and constants
         auto* const_arg = arg.as<ConstantNode>();
         if (const_arg && !const_arg->is_scalar()) {
           Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
           new_args.push_back(var_arg);
-          const Function& last_func = functions_.back();
-          Array<Constant> fconstants(function_to_constants_[last_func]);
-          fconstants.push_back(GetRef<Constant>(const_arg));
-          function_to_constants_.Set(last_func, fconstants);
+          function_signature.push_back(arg);
         } else {
+          if (arg.as<VarNode>()) {
+            function_signature.push_back(arg);
+          }
           new_args.push_back(arg);
         }
       }
+      function_to_arguments_.Set(last_func, function_signature);
       final_call = Call(call->op, new_args, call->attrs, {});
     }
 
-    // Since the constants are kicked out of partitioned functions
+    // Since the constants are extracted from partitioned functions
     // a new call to the global function is needed
     if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
       auto glob_var = GetRef<GlobalVar>(glob_var_node);
@@ -105,34 +181,18 @@
       auto new_glob_func = VisitExpr(glob_func);
       if (!new_glob_func.same_as(glob_func)) {
         mod_->Update(glob_var, Downcast<Function>(new_glob_func));
-        Array<Expr> new_args = post_call->args;
-        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
-        for (auto constant : function_to_constants_.at(glob_func)) {
-          new_args.push_back(constant);
-        }
+        auto new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), glob_func);
         final_call = Call(glob_var, new_args);
       }
     }
 
-    // Since the constants are kicked out of the local partitioned functions
+    // Since the constants are extracted from the local partitioned functions
     // a new call to the local function is needed
-    // Also, pass on the constants to the callee of this function to support nested functions
     if (auto* func_node = call->op.as<FunctionNode>()) {
       Function func = GetRef<Function>(func_node);
      auto new_func = VisitExpr(func);
-      if (!new_func.same_as(func)) {
-        Array<Expr> new_args = post_call->args;
-        ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
-        const Function& last_func = functions_.back();
-        Array<Constant> fconstants(function_to_constants_[last_func]);
-        for (auto constant : function_to_constants_.at(func)) {
-          fconstants.push_back(constant);
-          Var var_arg = Var(gen_var_name(), constant->tensor_type());
-          new_args.push_back(var_arg);
-        }
-        function_to_constants_.Set(last_func, fconstants);
-        final_call = Call(new_func, new_args);
-      }
+      Array<Expr> new_args = CreateNewCallArgsFromExtractedConstants(GetRef<Call>(post_call), func);
+      final_call = Call(new_func, new_args);
     }
 
     return final_call;
@@ -141,15 +201,16 @@
  private:
   /* \brief Updated module where all calls have replaced constants with new variables */
   IRModule mod_;
-  /* \brief Maintains mapping of original function to the replaced constants */
-  Map<Function, Array<Constant>> function_to_constants_;
-  /* \brief Stack of functions to determine scope while filling up function_to_constants_ */
+  /* \brief Maintains mapping of the original function to the replaced constants along with the
+   * other arguments, to retain the order in which variables are used within the function */
+  Map<Function, Array<Expr>> function_to_arguments_;
+  /* \brief Stack of functions to determine scope while filling up function_to_arguments_ */
   Array<Function> functions_;
   /* \brief Keeps track of variables being created */
   int var_count_ = 0;
 };
 
-/*! * \brief Kicks out all constants out of the partitioned function into main() */
+/*! * \brief Extracts all constants from the partitioned function into main() */
 IRModule ExtractConstants(const IRModule& mod) {
   String func_name;
   Function func;
@@ -169,7 +230,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
   runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
       [=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
   return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
-                                          {});
+                                          {"InferType"});
 }
 
 TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
new file mode 100644
index 000000000000..24ba0738be68
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file scalar_to_tensor_constant.cc
+ * \brief Converts scalar constants into tensor constants for binary ops of CMSIS-NN
+ */
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator finds all partitioned functions meant for CMSIS-NN binary ops.
+ * Then, it substitutes the scalar constants with tensor constants. It makes the shape of this
+ * new constant the same as that of the neighbouring constant of the other binary operand. The
+ * expectation is that the ExtractConstants pass would later extract this tensor constant out of
+ * the global partitioned function, thus making the entire global partitioned function and its
+ * composite function constant-free. This makes the TIR generation for binary ops via CMSIS-NN
+ * independent of constants.
+ */
+class ScalarToTensorConstantMutator : public MixedModeMutator {
+ public:
+  explicit ScalarToTensorConstantMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  using MixedModeMutator::VisitExpr_;
+
+  // Here is an example with the annotated scalar constant:
+  // def @tvmgen_default_cmsis_nn_main_1(%cmsis_nn_input: Tensor[], Inline=1, Compiler="cmsis-nn",
+  //                                     global_symbol="tvmgen_default_cmsis_nn_main",
+  //                                     Primitive=1) -> Tensor[] {
+  //   %56 = fn (%input0: _scalar_constant_, %input1: Tensor[],
+  //             PartitionedFromPattern="qnn.mul_", Composite="cmsis-nn.qnn_mul") -> Tensor[] {
+  //     qnn.mul(%input0, %input1, scale0, zero_point0,
+  //             scale1, zero_point_1, output_scale, output_zero_point)
+  //   };
+  //   %56(meta[relay.Constant] /* _scalar_constant_ */, %cmsis_nn_input)
+  // }
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    call = post.as<CallNode>();
+
+    // Create a new variable argument that is of the same shape as the neighbouring argument
+    // in the binary op. This needs to be done only when one of the arguments is a scalar.
+    if (call->op.as<OpNode>()) {
+      final_call = ReplaceScalarWithTensorVariable(GetRef<Call>(call));
+    }
+
+    if (auto* glob_var_node = call->op.as<GlobalVarNode>()) {
+      GlobalVar global_var = GetRef<GlobalVar>(glob_var_node);
+      Function func = Downcast<Function>(mod_->Lookup(global_var));
+      auto compiler_name = func->GetAttr<String>(::tvm::relay::attr::kCompiler);
+      if (!compiler_name.defined() || compiler_name != "cmsis-nn") {
+        return final_call;
+      }
+      auto new_body = VisitExpr(func->body);
+      if (new_body.same_as(func->body)) {
+        return final_call;
+      }
+      Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+      mod_->Update(global_var, new_func);
+      final_call = Call(global_var, call->args);
+    }
+
+    // Substitute the scalar constant with a tensor constant in the call to the composite
+    // function comprising partitioned binary ops. The shape of the new constant should be
+    // the same as its neighbouring tensor's shape.
+    if (auto* func_node = call->op.as<FunctionNode>()) {
+      Function func = GetRef<Function>(func_node);
+      auto func_name = func->GetAttr<String>(attr::kComposite);
+      if (func_name.defined() &&
+          (func_name == "cmsis-nn.qnn_add" || func_name == "cmsis-nn.qnn_mul")) {
+        final_call = ReplaceScalarWithTensorConstant(GetRef<Call>(call), func);
+      }
+    }
+
+    return final_call;
+  }
+
+  // Replaces a scalar variable with a tensor variable of the same shape as that of the
+  // neighbouring operand tensor in a binary op
+  Call ReplaceScalarWithTensorVariable(Call call) {
+    const OpNode* opnode = call->op.as<OpNode>();
+    if (opnode == nullptr) {
+      return call;
+    }
+    String op_name = opnode->name;
+    Array<Expr> new_args;
+    for (uint32_t i = 0; i < call->args.size(); ++i) {
+      Expr arg = call->args[i];
+      new_args.push_back(arg);
+      if (!arg->checked_type_.defined()) {
+        continue;
+      }
+      auto* arg_type = arg->type_as<TensorTypeNode>();
+      if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+        continue;
+      }
+      String arg_name = arg.as<VarNode>()->name_hint();
+      int tensor_arg_id = (i + 1) % 2;
+      Expr tensor_arg = call->args[tensor_arg_id];
+      if (!tensor_arg->checked_type_.defined()) {
+        continue;
+      }
+      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+      new_args.Set(i, Var(arg_name, tensor_type));
+    }
+    return Call(call->op, new_args, call->attrs, {});
+  }
+
+  // Makes a tensor constant of the same shape as tensor_arg with values from scalar_arg
+  Call ReplaceScalarWithTensorConstant(Call call, Function func) {
+    Array<Expr> new_args;
+    for (uint32_t i = 0; i < call->args.size(); ++i) {
+      new_args.push_back(call->args[i]);
+      Expr scalar_arg = call->args[i];
+      if (!scalar_arg->checked_type_.defined()) {
+        continue;
+      }
+      Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
+      if (scalar_shape.size() != 0 || scalar_arg.as<ConstantNode>() == nullptr) {
+        continue;
+      }
+      int tensor_arg_id = (i + 1) % 2;
+      Expr tensor_arg = call->args[tensor_arg_id];
+      if (!tensor_arg->checked_type_.defined()) {
+        continue;
+      }
+      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+      std::vector<int64_t> tensor_shape;
+      for (auto& dim : tensor_type->shape) {
+        tensor_shape.push_back(qnn::get_const_int(dim));
+      }
+      int8_t scalar_value = GetScalarFromConstant<int8_t>(scalar_arg);
+      int tensor_num_elements = qnn::get_const_int(tensor_type->Size());
+      std::vector<int8_t> tensor_values(tensor_num_elements, scalar_value);
+      Constant tensor_constant =
+          MakeConstantTensor(DataType::Int(8), tensor_shape, tensor_values);
+      new_args.Set(i, tensor_constant);
+    }
+    auto new_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
+                                   FreeTypeVars(new_body, mod_), func->attrs);
+    return Call(new_func, new_args);
+  }
+
+ private:
+  IRModule mod_;
+};
+
+IRModule ScalarToTensorConstant(const IRModule& mod) {
+  auto mutator = ScalarToTensorConstantMutator(mod);
+  Function main_func = Downcast<Function>(mod->Lookup("main"));
+  auto new_main_body = mutator.VisitExpr(main_func->body);
+  if (!new_main_body.same_as(main_func->body)) {
+    auto main_var = mod->GetGlobalVar("main");
+    auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
+                                  main_func->type_params, main_func->attrs);
+    mod->Update(main_var, new_main_func);
+  }
+  return mod;
+}
+
+transform::Pass ScalarToTensorConstantPass() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule m, transform::PassContext pc) { return ScalarToTensorConstant(m); };
+  return tvm::transform::CreateModulePass(pass_func, 0, "ScalarToTensorConstant", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ScalarToTensorConstants")
+    .set_body_typed(ScalarToTensorConstantPass);
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/contrib/test_cmsisnn/test_binary_ops.py b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
index 39b8c5fd21c6..f6417acbe613 100644
--- a/tests/python/contrib/test_cmsisnn/test_binary_ops.py
+++ b/tests/python/contrib/test_cmsisnn/test_binary_ops.py
@@ -17,9 +17,11 @@
 
 """CMSIS-NN integration tests: binary ops"""
 
+import itertools
 import sys
 
 import numpy as np
+from enum import Enum
 import pytest
 import tvm
@@ -35,11 +37,29 @@
 )
 
 
+def generate_tensor_constant():
+    rng = np.random.default_rng(12321)
+    dtype = "int8"
+    shape = (1, 16, 16, 3)
+    values = tvm.nd.array(
+        rng.integers(np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=shape, dtype=dtype)
+    )
+    return relay.const(values, dtype)
+
+
+def generate_scalar_constant():
+    dtype = "int8"
+    return relay.const(-30, dtype)
+
+
+def generate_variable(name, dtype="int8"):
+    return relay.var(name, shape=(1, 16, 16, 3), dtype=dtype)
+
+
 def make_model(
     op,
-    shape,
-    input_0_dtype,
-    input_1_dtype,
+    input_0,
+    input_1,
     input_0_scale,
     input_0_zero_point,
     input_1_scale,
@@ -48,10 +68,9 @@ def make_model(
     out_zero_point=-128,
 ):
     """Create a Relay Function / network model"""
     return op(
-        relay.var("input_0", shape=shape, dtype=input_0_dtype),
-        relay.var("input_1", shape=shape, dtype=input_1_dtype),
+        input_0,
+        input_1,
         relay.const(input_0_scale, "float32"),
         relay.const(input_0_zero_point, "int32"),
         relay.const(input_1_scale, "float32"),
@@ -82,9 +101,8 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
     shape = [1, 16, 16, 3]
     model = make_model(
         op,
-        shape,
-        dtype,
-        dtype,
+        generate_variable("input_0"),
+        generate_variable("input_1"),
         input_0_scale,
         input_0_zero_point,
         input_1_scale,
@@ -131,6 +149,128 @@ def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_z
     )
 
 
+# At least one of the inputs must be a constant; both can't be variables and both can't be scalars
+def parameterize_for_constant_inputs(test):
+    op = [relay.qnn.op.mul, relay.qnn.op.add]
+    input_0 = [generate_variable("input_0"), generate_tensor_constant(), generate_scalar_constant()]
+    input_1 = [generate_variable("input_1"), generate_tensor_constant(), generate_scalar_constant()]
+    all_combinations = itertools.product(op, input_0, input_1)
+    all_combinations = filter(
+        lambda parameters: not (
+            (
+                isinstance(parameters[1], tvm.relay.expr.Var)
+                and isinstance(parameters[2], tvm.relay.expr.Var)
+            )
+            or (
+                isinstance(parameters[1], tvm.relay.expr.Constant)
+                and isinstance(parameters[2], tvm.relay.expr.Constant)
+                and parameters[1].data.numpy().ndim == 0
+                and parameters[2].data.numpy().ndim == 0
+            )
+        ),
+        all_combinations,
+    )
+    return pytest.mark.parametrize(
+        ["op", "input_0", "input_1"],
+        all_combinations,
+    )(test)
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@parameterize_for_constant_inputs
+def test_constant_input_int8(op, input_0, input_1):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    input_0_scale = 0.256
+    input_0_zero_point = 33
+    input_1_scale = 0.128
+    input_1_zero_point = -24
+    model = make_model(
+        op,
+        input_0,
+        input_1,
+        input_0_scale,
+        input_0_zero_point,
+        input_1_scale,
+        input_1_zero_point,
+    )
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    inputs = {}
+    if isinstance(input_0, tvm.relay.expr.Var):
+        inputs.update({"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)})
+    if isinstance(input_1, tvm.relay.expr.Var):
+        inputs.update({"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)})
+    output_list = generate_ref_data(orig_mod["main"], inputs)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
+def test_both_scalar_inputs_int8(
+    op,
+):
+    input_scale = 0.256
+    input_zero_point = 33
+    dtype = "int8"
+    model = make_model(
+        op,
+        generate_scalar_constant(),
+        generate_scalar_constant(),
+        input_scale,
+        input_zero_point,
+        input_scale,
+        input_zero_point,
+    )
+
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert not any(attrs), "No function should have an external attribute."
+
+
 @skip_if_no_reference_system
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@@ -143,9 +283,8 @@ def test_invalid_parameters(
     op,
     input_dtype,
 ):
     input_scale = 0.256
     input_zero_point = 33
     model = make_model(
         op,
-        [1, 16, 16, 3],
-        input_dtype,
-        input_dtype,
+        generate_variable("input_0", input_dtype),
+        generate_variable("input_1", input_dtype),
         input_scale,
         input_zero_point,
         input_scale,
diff --git a/tests/python/contrib/test_cmsisnn/test_extract_constants.py b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
index ca3fbe6b20ed..8e251777716a 100644
--- a/tests/python/contrib/test_cmsisnn/test_extract_constants.py
+++ b/tests/python/contrib/test_cmsisnn/test_extract_constants.py
@@ -23,15 +23,6 @@
 import tvm
 from tvm import relay
 
-from utils import (
-    make_module,
-    count_num_calls,
-    get_range_for_dtype_str,
-    get_same_padding,
-    get_conv2d_qnn_params,
-    make_qnn_relu,
-)
-
 tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
@@ -136,7 +127,6 @@ def test_multiple_functions():
     c10 = relay.Call(f20, [x10])
     c11 = relay.Call(f21, [c10])
     ef = relay.Function([x10], c11, relay.TensorType((8, 8), "float32"))
-    x0 = relay.var("x0", shape=(8, 8))
 
     ev = relay.GlobalVar("cmsis-nn")
     ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
@@ -182,56 +172,42 @@ def test_main_function():
     ), "main() should have same number of arguments as before"
 
 
-def parameterize_for_invalid_model(test):
-    local_func_1 = ["cmsis-nn.qnn_op_1", "local_function_1"]
-    local_func_2 = ["cmsis-nn.qnn_op_2", "local_function_2"]
-    compiler_name = ["cmsis-nn", "external_compiler"]
-    all_combinations = itertools.product(local_func_1, local_func_2, compiler_name)
-    all_combinations = filter(
-        lambda parameters: not (
-            parameters[2] == "cmsis-nn"
-            and parameters[0] == "cmsis-nn.qnn_op_1"
-            and parameters[1] == "cmsis-nn.qnn_op_2"
-        ),
-        all_combinations,
-    )
-    return pytest.mark.parametrize(
-        ["func_name_1", "func_name_2", "external_compiler"],
-        all_combinations,
-    )(test)
-
-
 @tvm.testing.requires_cmsisnn
-@parameterize_for_invalid_model
-def test_multiple_functions_non_cmsisnn_compiler(func_name_1, func_name_2, external_compiler):
+@pytest.mark.parametrize("external_compiler", ["cmsis-nn", "other_compiler"])
+def test_multiple_functions_non_cmsisnn_compiler(external_compiler):
     y20_data = np.random.uniform(0, 1, (8, 8)).astype("float32")
     x20 = relay.var("x20", shape=(8, 8))
     y20_const = relay.const(y20_data, "float32")
     z20 = x20 + y20_const
     f20 = relay.Function([x20], z20, relay.TensorType((8, 8), "float32"))
-    f20 = set_composite_func_attr(f20, func_name_1)
+    f20 = set_composite_func_attr(f20, "cmsis-nn.qnn_op_1")
+    x10 = relay.var("x10", shape=(8, 8))
+    c10 = relay.Call(f20, [x10])
+    ef0 = relay.Function([x10], c10, relay.TensorType((8, 8), "float32"))
 
     y21_data = np.random.uniform(0, 1, (8, 8)).astype("float32")
     x21 = relay.var("x21", shape=(8, 8))
     y21_const = relay.const(y21_data, "float32")
     z21 = x21 + y21_const
     f21 = relay.Function([x21], z21, relay.TensorType((8, 8), "float32"))
-    f21 = set_composite_func_attr(f21, func_name_2)
-
-    x10 = relay.var("x10", shape=(8, 8))
-    c10 = relay.Call(f20, [x10])
-    c11 = relay.Call(f21, [c10])
-    ef = relay.Function([x10], c11, relay.TensorType((8, 8), "float32"))
+    f21 = set_composite_func_attr(f21, "cmsis-nn.qnn_op_2")
+    x11 = relay.var("x11", shape=(8, 8))
+    c11 = relay.Call(f21, [x11])
+    ef1 = relay.Function([x11], c11, relay.TensorType((8, 8), "float32"))
 
     x0 = relay.var("x0", shape=(8, 8))
-    ev = relay.GlobalVar("external_function")
-    ef = set_external_func_attr(ef, external_compiler, ev.name_hint)
-    c = relay.Call(ev, [x0])
-    mf = relay.Function([x0], c, relay.TensorType((8, 8), "float32"))
+    ev0 = relay.GlobalVar("external_function_0")
+    ef0 = set_external_func_attr(ef0, external_compiler, ev0.name_hint)
+    c0 = relay.Call(ev0, [x0])
+    ev1 = relay.GlobalVar("external_function_1")
+    ef1 = set_external_func_attr(ef1, external_compiler, ev1.name_hint)
+    c1 = relay.Call(ev1, [c0])
+    mf = relay.Function([x0], c1, relay.TensorType((8, 8), "float32"))
     mv = relay.GlobalVar("main")
 
     mod = tvm.IRModule()
-    mod[ev] = ef
+    mod[ev0] = ef0
+    mod[ev1] = ef1
     mod[mv] = mf
 
     mod = ExtractConstantsFromPartitionedFunction()(mod)
@@ -240,10 +216,7 @@ def test_multiple_functions_non_cmsisnn_compiler(func_name_1, func_name_2, exter
     num_extracted_constants = 0
     if external_compiler == "cmsis-nn":
-        if "cmsis-nn" in func_name_1:
-            num_extracted_constants += 1
-        if "cmsis-nn" in func_name_2:
-            num_extracted_constants += 1
+        num_extracted_constants = 2
 
     assert (
         check_for_constants.num_constants_ == num_extracted_constants
diff --git a/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
new file mode 100644
index 000000000000..703961728f28
--- /dev/null
+++ b/tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py
@@ -0,0 +1,187 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: scalar_to_tensor_constant pass"""
+import sys
+
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+
+tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
+
+
+class CheckFunctionsForConstants(tvm.relay.ExprVisitor):
+    def __init__(self):
+        super().__init__()
+        self.num_constants_ = 0
+
+    def visit_call(self, call):
+        super().visit_call(call)
+        for arg in call.args:
+            if isinstance(arg, relay.Constant) and arg.data.numpy().ndim > 0:
+                self.num_constants_ += 1
+
+    def check_num_constants(self, func):
+        assert self.num_constants_ == 0, "Functions should not have constant arguments in Calls"
+
+
+def set_external_func_attr(func, compiler, ext_symbol):
+    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Compiler", compiler)
+    func = func.with_attr("global_symbol", ext_symbol)
+    return func
+
+
+def set_composite_func_attr(func, name):
+    func = func.with_attr("Composite", name)
+    return func
+
+
+@tvm.testing.requires_cmsisnn
+def test_single_scalar_position_0():
+    x0 = relay.var("x0", shape=None)
+    x1 = relay.var("x1", shape=(8, 8))
+    z1 = x0 + x1
+    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.expr.const(3, "float32")
+    y1 = relay.var("y1", shape=(8, 8))
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([y1], c0, relay.TensorType((8, 8), "float32"))
+
+    x = relay.var("x", shape=(8, 8))
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [x])
+    mf = relay.Function([x], c, relay.TensorType((8, 8), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 1
+    ), "Scalar constant wasn't converted into tensor constant"
+
+
+@tvm.testing.requires_cmsisnn
+def test_single_scalar_position_1():
+    x0 = relay.var("x0", shape=(8, 8))
+    x1 = relay.var("x1", shape=None)
+    z1 = x0 + x1
+    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.var("y0", shape=(8, 8))
+    y1 = relay.expr.const(3, "float32")
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([y0], c0, relay.TensorType((8, 8), "float32"))
+
+    x = relay.var("x", shape=(8, 8))
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [x])
+    mf = relay.Function([x], c, relay.TensorType((8, 8), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 1
+    ), "Scalar constant wasn't converted into tensor constant"
+
+
+@tvm.testing.requires_cmsisnn
+def test_two_scalars():
+    x1 = relay.var("x1", shape=None)
+    x2 = relay.var("x2", shape=None)
+    z1 = x1 + x2
+    lf = relay.Function([x1, x2], z1, relay.TensorType((), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.expr.const(5, "float32")
+    y1 = relay.expr.const(3, "float32")
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([], c0, relay.TensorType((), "float32"))
+
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [])
+    mf = relay.Function([], c, relay.TensorType((), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 0
+    ), "Scalar constants should not be converted when both operands are scalars"
+
+
+@tvm.testing.requires_cmsisnn
+def test_two_tensor_constants():
+    x0 = relay.var("x0", shape=(8, 8))
+    x1 = relay.var("x1", shape=(8, 8))
+    z1 = x0 + x1
+    lf = relay.Function([x0, x1], z1, relay.TensorType((8, 8), "float32"))
+    lf = set_composite_func_attr(lf, "cmsis-nn.qnn_add")
+
+    y0 = relay.const(np.random.uniform(0, 1, (8, 8)).astype("float32"), "float32")
+    y1 = relay.const(np.random.uniform(0, 1, (8, 8)).astype("float32"), "float32")
+    c0 = relay.Call(lf, [y0, y1])
+    ef = relay.Function([], c0, relay.TensorType((8, 8), "float32"))
+
+    ev = relay.GlobalVar("external_function")
+    ef = set_external_func_attr(ef, "cmsis-nn", ev.name_hint)
+    c = relay.Call(ev, [])
+    mf = relay.Function([], c, relay.TensorType((8, 8), "float32"))
+    mv = relay.GlobalVar("main")
+
+    mod = tvm.IRModule()
+    mod[ev] = ef
+    mod[mv] = mf
+
+    mod = relay.transform.InferType()(mod)
+    mod = ScalarToTensorConstants()(mod)
+    check_for_constants = CheckFunctionsForConstants()
+    check_for_constants.visit_call(mod[ev].body)
+    assert (
+        check_for_constants.num_constants_ == 2
+    ), "Tensor constants should not have been modified"
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
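
Usage note for reviewers (illustrative, not part of the patch): the sketch below drives the new ScalarToTensorConstants pass standalone, mirroring test_single_scalar_position_0 above. It assumes a TVM build that compiles the CMSIS-NN contrib sources so that the relay.ext.cmsisnn.transform API is registered; the variable names are made up for this example.

    # scalar_to_tensor_demo.py -- minimal sketch, assuming the patched TVM build
    import tvm
    from tvm import relay

    # Registers ScalarToTensorConstants() into this module's namespace, as the new tests do
    tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)

    # Composite function whose first operand is a scalar -- the pattern the pass rewrites
    x0 = relay.var("x0", shape=None)
    x1 = relay.var("x1", shape=(8, 8))
    local_func = relay.Function([x0, x1], x0 + x1, relay.TensorType((8, 8), "float32"))
    local_func = local_func.with_attr("Composite", "cmsis-nn.qnn_add")

    # Partitioned global function that calls the composite with a 0-d constant
    y1 = relay.var("y1", shape=(8, 8))
    ext_func = relay.Function(
        [y1],
        relay.Call(local_func, [relay.const(3.0, "float32"), y1]),
        relay.TensorType((8, 8), "float32"),
    )
    ext_func = ext_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    ext_func = ext_func.with_attr("Compiler", "cmsis-nn")
    ext_var = relay.GlobalVar("external_function")
    ext_func = ext_func.with_attr("global_symbol", ext_var.name_hint)

    x = relay.var("x", shape=(8, 8))
    main_func = relay.Function([x], relay.Call(ext_var, [x]), relay.TensorType((8, 8), "float32"))

    mod = tvm.IRModule()
    mod[ext_var] = ext_func
    mod[relay.GlobalVar("main")] = main_func

    mod = relay.transform.InferType()(mod)
    # After the pass, the 0-d constant fed to the composite call is an (8, 8) tensor constant,
    # which ExtractConstantsFromPartitionedFunction can then hoist out into main()
    print(ScalarToTensorConstants()(mod))

As in the tests, the float32 scalar here only exercises the shape rewrite; the pass materialises int8 tensor values, matching its intended int8 qnn binary-op inputs.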