From ebb7408d72f9f8b4010fd89d05446b17c15775ac Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Fri, 1 Jul 2022 18:49:54 +0100 Subject: [PATCH 1/3] [CMSIS-NN][Perf] Converted Relay Conv2D into CMSIS-NN Depthwise Change-Id: I49cd262ce057b2a314f56aadca9b12995a519e88 --- .../contrib/cmsisnn/generate_constants.cc | 7 +- .../backend/contrib/cmsisnn/relay_to_tir.cc | 13 ++- src/relay/backend/contrib/cmsisnn/utils.cc | 47 ++++++++ src/relay/backend/contrib/cmsisnn/utils.h | 60 ++++++++++ .../contrib/test_cmsisnn/test_conv2d.py | 109 +++++++++++++++++- 5 files changed, 224 insertions(+), 12 deletions(-) create mode 100644 src/relay/backend/contrib/cmsisnn/utils.cc create mode 100644 src/relay/backend/contrib/cmsisnn/utils.h diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 450bcf26d1b3..ab673b500cf2 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -31,6 +31,7 @@ #include "../../../op/make_op.h" #include "../../../qnn/utils.h" #include "../../../transforms/pattern_utils.h" +#include "utils.h" namespace tvm { namespace relay { @@ -111,11 +112,7 @@ class GenerateConstantsMutator : public MixedModeMutator { Array input_shape = conv2d_call->args[0]->type_as()->shape; Array kernel_shape = conv2d_call->args[1]->type_as()->shape; - std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); - int kernel_pos_o = kernel_layout.find("O"); - int groups = conv2d_attrs->groups; - if (groups != qnn::get_const_int(input_shape[3]) || - groups != qnn::get_const_int(kernel_shape[kernel_pos_o])) { + if (!is_cmsisnn_depthwise(conv2d_attrs, input_shape, kernel_shape)) { // Transpose weights: HWIO -> OHWI for Conv2D conv2d_kernel = ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs); } diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 5c99061fa854..d60e50317e4c 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -31,6 +30,7 @@ #include "../../../transforms/pattern_utils.h" #include "buffer_size.h" #include "compiler_attrs.h" +#include "utils.h" namespace tvm { namespace relay { @@ -173,7 +173,6 @@ class RelayToTIRVisitor : public MixedModeMutator { int32_t dilation_w = qnn::get_const_int(conv2d_attrs->dilation[1]); int32_t dilation_h = qnn::get_const_int(conv2d_attrs->dilation[0]); int32_t out_channels = qnn::get_const_int(conv2d_attrs->channels); - int32_t groups = conv2d_attrs->groups; std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); int32_t clip_min = std::numeric_limits::min(); int32_t clip_max = std::numeric_limits::max(); @@ -207,11 +206,13 @@ class RelayToTIRVisitor : public MixedModeMutator { int32_t output_c = qnn::get_const_int(output_shape[3]); int32_t depth_multiplier = -1; - int kernel_pos_o = kernel_layout.find("O"); - if (groups == qnn::get_const_int(input_shape[3]) && - groups == qnn::get_const_int(filter_shape[kernel_pos_o])) { + if (is_cmsisnn_depthwise(conv2d_attrs, input_shape, filter_shape)) { + // Refer to TVM frontend to know how depth multiplier and out_channels are related + // https://github.com/apache/tvm/blob/6ed3ab3e33f8eafa4acaf53b7a671831de7587e9/python/tvm/relay/frontend/tflite.py#L2129 int kernel_pos_i = kernel_layout.find("I"); - depth_multiplier = qnn::get_const_int(filter_shape[kernel_pos_i]); + int kernel_pos_o = kernel_layout.find("O"); + int kernel_pos_dm = input_c == 1 ? kernel_pos_o : kernel_pos_i; + depth_multiplier = qnn::get_const_int(filter_shape[kernel_pos_dm]); } scalar_args.push_back(ToArg(depth_multiplier)); diff --git a/src/relay/backend/contrib/cmsisnn/utils.cc b/src/relay/backend/contrib/cmsisnn/utils.cc new file mode 100644 index 000000000000..a86bd51de105 --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/utils.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../qnn/utils.h" + +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { + +bool is_cmsisnn_depthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, + const Array& kernel_shape) { + std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); + int kernel_pos_o = kernel_layout.find("O"); + int kernel_pos_i = kernel_layout.find("I"); + int kernel_dim_o_val = qnn::get_const_int(kernel_shape[kernel_pos_o]); + int kernel_dim_i_val = qnn::get_const_int(kernel_shape[kernel_pos_i]); + int64_t out_channels = conv2d_attrs->channels.as()->value; + if (out_channels == kernel_dim_o_val * kernel_dim_i_val) { + return true; + } + return false; +} + +} // namespace cmsisnn +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/cmsisnn/utils.h b/src/relay/backend/contrib/cmsisnn/utils.h new file mode 100644 index 000000000000..deb704fe038e --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/utils.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/cmsisnn/utils.h + * \brief CMSIS-NN utility functions + */ + +#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_UTILS_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "../../../op/make_op.h" +#include "../../../qnn/utils.h" +#include "../../../transforms/pattern_utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { +/*! + * \brief Checks if Relay Conv2D was originally CMSIS-NN compliant Depthwise Convolution + * See: + * https://github.com/apache/tvm/blob/6ed3ab3e33f8eafa4acaf53b7a671831de7587e9/python/tvm/relay/frontend/tflite.py#L2107 + * + * + * \return true if a Conv2D is a Depthwise Convolution based on Conv2D's inputs' shapes and + * attributes + */ + +bool is_cmsisnn_depthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, + const Array& kernel_shape); + +} // namespace cmsisnn +} // namespace contrib +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_UTILS_H_ diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 462eb8834719..8ed0b01375a3 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -84,13 +84,14 @@ def make_model( ) ) weight_const = relay.const(weight, kernel_dtype) + conv2d_kernel_sc = kernel_scale[0] if out_channels == 1 else kernel_scale conv = relay.qnn.op.conv2d( invar, weight_const, input_zero_point=relay.const(input_zero_point, "int32"), kernel_zero_point=relay.const(kernel_zero_point, "int32"), input_scale=relay.const(input_scale, "float32"), - kernel_scale=relay.const(kernel_scale, "float32"), + kernel_scale=relay.const(conv2d_kernel_sc, "float32"), kernel_size=(kernel_h, kernel_w), data_layout="NHWC", kernel_layout=weight_format, @@ -105,6 +106,7 @@ def make_model( bias_const = relay.const(bias, "int32") last_op = relay.nn.bias_add(conv, bias_const, axis=3) if enable_bias else conv requant_input_sc = [sc * input_scale for sc in kernel_scale] + requant_input_sc = requant_input_sc[0] if out_channels == 1 else requant_input_sc last_op = relay.qnn.op.requantize( last_op, relay.const(requant_input_sc, "float32"), @@ -540,6 +542,111 @@ def test_depthwise_int8( ) +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))]) +@pytest.mark.parametrize("relu_type", ["RELU", "NONE"]) +@pytest.mark.parametrize("depth_multiplier", [1, 3]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale", + [ + ( + 10, + 0.0128, + [0.11, 0.22], + ), + ( + -64, + 1, + [1, 0.0256, 1.37], + ), + ], +) +def test_relay_conv2d_cmsisnn_depthwise_int8( + padding, + strides, + dilation, + relu_type, + input_zero_point, + input_scale, + kernel_scale, + depth_multiplier, +): + """Tests QNN Depthwise int8 op via CMSIS-NN""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_USMP_CORSTONE300_RUNNER + + dtype = "int8" + in_min, in_max = get_range_for_dtype_str(dtype) + + ifm_shape = (1, 24, 24, 1) + groups = ifm_shape[3] + weight_format = "HWIO" + (kernel_h, kernel_w) = (3, 3) + kernel_shape = (kernel_h, kernel_w, ifm_shape[3], depth_multiplier) + out_channels = ifm_shape[3] * depth_multiplier + enable_bias = True + ks_len = len(kernel_scale) + kernel_zero_point = 0 + kernel_scale = [kernel_scale[i % ks_len] for i in range(out_channels)] + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + True, + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + assert_partitioned_function(orig_mod, cmsisnn_mod) + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + def parameterize_for_invalid_model(test): """Generates non int8 inputs""" in_dtype = ["uint8", "int8"] From fea574bfda0fc39bac55668c3ec9c03e83f69741 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Tue, 5 Jul 2022 10:36:24 +0100 Subject: [PATCH 2/3] Included Depthwise CMSIS-NN sources in zephyr demo Change-Id: Ieace94bd0f29d92660bd8e3a2523c1bf42b6afaa --- apps/microtvm/zephyr_cmsisnn/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/apps/microtvm/zephyr_cmsisnn/CMakeLists.txt b/apps/microtvm/zephyr_cmsisnn/CMakeLists.txt index b09e1d0642d2..dd3582f86f7d 100644 --- a/apps/microtvm/zephyr_cmsisnn/CMakeLists.txt +++ b/apps/microtvm/zephyr_cmsisnn/CMakeLists.txt @@ -53,6 +53,11 @@ set(DATA_FILES ) set(CMSIS_SOURCES ${CMSIS_PATH}/CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_s8.c + ${CMSIS_PATH}/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_wrapper_s8.c + ${CMSIS_PATH}/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_s8.c + ${CMSIS_PATH}/CMSIS/NN/Source/ConvolutionFunctions/arm_depthwise_conv_s8_opt.c + ${CMSIS_PATH}/CMSIS/NN/Source/NNSupportFunctions/arm_nn_depthwise_conv_nt_t_s8.c + ${CMSIS_PATH}/CMSIS/NN/Source/NNSupportFunctions/arm_nn_depthwise_conv_nt_t_padded_s8.c ${CMSIS_PATH}/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c ${CMSIS_PATH}/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_1_x_n_s8.c ${CMSIS_PATH}/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_1x1_s8_fast.c From 7a0c89f7addd3edcbb9e42aea42310753e137523 Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi Date: Thu, 7 Jul 2022 16:55:06 +0100 Subject: [PATCH 3/3] C style function naming and additional check around Depthwise CMSIS-NN API Change-Id: Iec889ea4d2989617c085803aa9e72eb3296f2d35 --- .../cmsisnn/{utils.cc => convolutions.cc} | 17 ++++---- .../cmsisnn/{utils.h => convolutions.h} | 14 +++---- .../contrib/cmsisnn/generate_constants.cc | 4 +- .../backend/contrib/cmsisnn/relay_to_tir.cc | 4 +- .../contrib/test_cmsisnn/test_conv2d.py | 41 ++++++++++++++++--- 5 files changed, 54 insertions(+), 26 deletions(-) rename src/relay/backend/contrib/cmsisnn/{utils.cc => convolutions.cc} (80%) rename src/relay/backend/contrib/cmsisnn/{utils.h => convolutions.h} (78%) diff --git a/src/relay/backend/contrib/cmsisnn/utils.cc b/src/relay/backend/contrib/cmsisnn/convolutions.cc similarity index 80% rename from src/relay/backend/contrib/cmsisnn/utils.cc rename to src/relay/backend/contrib/cmsisnn/convolutions.cc index a86bd51de105..ebac83b81250 100644 --- a/src/relay/backend/contrib/cmsisnn/utils.cc +++ b/src/relay/backend/contrib/cmsisnn/convolutions.cc @@ -16,29 +16,28 @@ * specific language governing permissions and limitations * under the License. */ +#include "convolutions.h" -#include "../../../qnn/utils.h" +#include -#include -#include +#include "../../../qnn/utils.h" +#include "tvm/ir/transform.h" +#include "tvm/relay/attrs/nn.h" namespace tvm { namespace relay { namespace contrib { namespace cmsisnn { -bool is_cmsisnn_depthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, - const Array& kernel_shape) { +bool IsCMSISNNDepthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, + const Array& kernel_shape) { std::string kernel_layout = conv2d_attrs->kernel_layout.c_str(); int kernel_pos_o = kernel_layout.find("O"); int kernel_pos_i = kernel_layout.find("I"); int kernel_dim_o_val = qnn::get_const_int(kernel_shape[kernel_pos_o]); int kernel_dim_i_val = qnn::get_const_int(kernel_shape[kernel_pos_i]); int64_t out_channels = conv2d_attrs->channels.as()->value; - if (out_channels == kernel_dim_o_val * kernel_dim_i_val) { - return true; - } - return false; + return (out_channels == kernel_dim_o_val * kernel_dim_i_val); } } // namespace cmsisnn diff --git a/src/relay/backend/contrib/cmsisnn/utils.h b/src/relay/backend/contrib/cmsisnn/convolutions.h similarity index 78% rename from src/relay/backend/contrib/cmsisnn/utils.h rename to src/relay/backend/contrib/cmsisnn/convolutions.h index deb704fe038e..e635702bf353 100644 --- a/src/relay/backend/contrib/cmsisnn/utils.h +++ b/src/relay/backend/contrib/cmsisnn/convolutions.h @@ -18,12 +18,12 @@ */ /*! - * \file src/relay/backend/contrib/cmsisnn/utils.h - * \brief CMSIS-NN utility functions + * \file src/relay/backend/contrib/cmsisnn/convolutions.h + * \brief CMSIS-NN utility functions for Convolutions */ -#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_UTILS_H_ -#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_UTILS_H_ +#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_CONVOLUTIONS_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_CONVOLUTIONS_H_ #include #include @@ -49,12 +49,12 @@ namespace cmsisnn { * attributes */ -bool is_cmsisnn_depthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, - const Array& kernel_shape); +bool IsCMSISNNDepthwise(const Conv2DAttrs* conv2d_attrs, const Array& input_shape, + const Array& kernel_shape); } // namespace cmsisnn } // namespace contrib } // namespace relay } // namespace tvm -#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_UTILS_H_ +#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_CONVOLUTIONS_H_ diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index ab673b500cf2..297e6b7acea3 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -31,7 +31,7 @@ #include "../../../op/make_op.h" #include "../../../qnn/utils.h" #include "../../../transforms/pattern_utils.h" -#include "utils.h" +#include "convolutions.h" namespace tvm { namespace relay { @@ -112,7 +112,7 @@ class GenerateConstantsMutator : public MixedModeMutator { Array input_shape = conv2d_call->args[0]->type_as()->shape; Array kernel_shape = conv2d_call->args[1]->type_as()->shape; - if (!is_cmsisnn_depthwise(conv2d_attrs, input_shape, kernel_shape)) { + if (!IsCMSISNNDepthwise(conv2d_attrs, input_shape, kernel_shape)) { // Transpose weights: HWIO -> OHWI for Conv2D conv2d_kernel = ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs); } diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index d60e50317e4c..d1d1d20d6e34 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -30,7 +30,7 @@ #include "../../../transforms/pattern_utils.h" #include "buffer_size.h" #include "compiler_attrs.h" -#include "utils.h" +#include "convolutions.h" namespace tvm { namespace relay { @@ -206,7 +206,7 @@ class RelayToTIRVisitor : public MixedModeMutator { int32_t output_c = qnn::get_const_int(output_shape[3]); int32_t depth_multiplier = -1; - if (is_cmsisnn_depthwise(conv2d_attrs, input_shape, filter_shape)) { + if (IsCMSISNNDepthwise(conv2d_attrs, input_shape, filter_shape)) { // Refer to TVM frontend to know how depth multiplier and out_channels are related // https://github.com/apache/tvm/blob/6ed3ab3e33f8eafa4acaf53b7a671831de7587e9/python/tvm/relay/frontend/tflite.py#L2129 int kernel_pos_i = kernel_layout.find("I"); diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 8ed0b01375a3..0b15c5a2466c 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -23,8 +23,13 @@ from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_models, compile_and_run - +from tvm.testing.aot import ( + generate_ref_data, + AOTTestModel, + compile_models, + compile_and_run, + run_and_check, +) from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from .utils import ( make_module, @@ -211,7 +216,7 @@ def test_conv2d_number_primfunc_args( cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"] assert ( len(cmsisnn_func.params) == expected_num_params - ), "Generated unexpected number of function arguments" + ), "Generated unexpected number of function arguments." @tvm.testing.requires_cmsisnn @@ -629,11 +634,13 @@ def test_relay_conv2d_cmsisnn_depthwise_int8( # validate pattern matching assert_partitioned_function(orig_mod, cmsisnn_mod) - # validate the output + # generate reference output rng = np.random.default_rng(12345) inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} output_list = generate_ref_data(orig_mod["main"], inputs, params) - compile_and_run( + + # validate presence of depthwise convolution + compiled_models = compile_models( AOTTestModel( module=cmsisnn_mod, inputs=inputs, @@ -641,9 +648,31 @@ def test_relay_conv2d_cmsisnn_depthwise_int8( params=params, output_tolerance=1, ), - test_runner, interface_api, use_unpacked_api, + pass_config=test_runner.pass_config, + ) + + cmsisnn_tir_mod = None + for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items(): + if target.kind.name == "cmsis-nn": + cmsisnn_tir_mod = mod + + cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"] + call_extern = None + if isinstance(cmsisnn_func.body, tvm.tir.stmt.Evaluate): + call_extern = cmsisnn_func.body.value + else: + call_extern = cmsisnn_func.body.body.value + assert ( + call_extern.args[0].value == "arm_depthwise_conv_wrapper_s8" + ), "Relay Conv2D should be mapped to CMSIS-NN Depthwise Convolution." + + # validate the output + run_and_check( + models=compiled_models, + runner=test_runner, + interface_api=interface_api, )