From 6ce5edfd6a7c15b6dc7979f0c0468a1d26806f7e Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 22 Nov 2021 22:10:09 +0000 Subject: [PATCH] [microNPU] Add unary elementwise operator infrastructure with ABS (#9530) * [microNPU] Add unary elementwise operator infrastructure with ABS * Added unary elementwise ABS legalization support and tests * Added unary_elementwise Relay to TIR lowering and tests * Added TIR to Vela translation and tests * Added codegen tests Co-authored-by: Rishabh Jain --- .../relay/backend/contrib/ethosu/legalize.py | 91 +++++++++ .../backend/contrib/ethosu/op/__init__.py | 1 + .../contrib/ethosu/op/unary_elementwise.py | 163 ++++++++++++++++ .../backend/contrib/ethosu/te/__init__.py | 1 + .../contrib/ethosu/te/unary_elementwise.py | 126 ++++++++++++ .../backend/contrib/ethosu/tir/passes.py | 2 + .../relay/backend/contrib/ethosu/tir/spec.py | 19 ++ .../contrib/ethosu/tir/unary_elementwise.py | 74 +++++++ .../contrib/ethosu/tir_to_cs_translator.py | 51 +++++ .../tvm/relay/backend/contrib/ethosu/util.py | 38 +++- python/tvm/relay/op/contrib/ethosu.py | 82 +++++++- .../op/contrib/ethosu/binary_elementwise.cc | 4 +- src/relay/op/contrib/ethosu/common.cc | 5 +- src/relay/op/contrib/ethosu/common.h | 5 +- .../op/contrib/ethosu/unary_elementwise.cc | 183 ++++++++++++++++++ tests/python/contrib/test_ethosu/infra.py | 28 +++ .../contrib/test_ethosu/test_codegen.py | 76 ++++++++ .../contrib/test_ethosu/test_legalize.py | 102 ++++++++++ .../test_replace_unary_elementwise.py | 155 +++++++++++++++ .../test_ethosu/test_type_inference.py | 57 +++++- 20 files changed, 1236 insertions(+), 27 deletions(-) create mode 100644 python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py create mode 100644 src/relay/op/contrib/ethosu/unary_elementwise.cc create mode 100644 tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 8095cb184f5b..274f148e9134 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -741,6 +741,96 @@ def __call__(self, *args, **kwargs): pass +class UnaryElementwiseRewriter(DFPatternCallback): + """ + Convert ethosu unary elementwise composite function to + ethosu_unary_elementwise operators + """ + + def __init__(self, params_class: Type, pattern: CallPattern): + super().__init__(require_type=True) + self.params_class = params_class + self.pattern = pattern + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = self.params_class(post.op.body) + params.ifm.tensor = post.args[0] + + if str(params.ofm.layout) != "NHWC": + raise UnsupportedLayout(str(params.ofm.layout)) + + activation_map = {"clip": "CLIP"} + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + activation = "NONE" + clip_min = 0 + clip_max = 0 + + # We don't yet support activation functions that use LUT. 
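+        # Pass an empty constant as a placeholder for the LUT input.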
+ lut = relay.const([], dtype="int8") + + unary_input_shape = params.ifm.shape + # If the input tensor is not 4D, enter reshapes before and after the unary operator + if len(params.ifm.shape) == 4: + unary_input = params.ifm.tensor + else: + pad_size = 4 - len(unary_input_shape) + unary_input_shape = ([1] * pad_size) + unary_input_shape + unary_input = relay.op.reshape(params.ifm.tensor, newshape=unary_input_shape) + + ethosu_unary_elementwise = ethosu_ops.ethosu_unary_elementwise( + ifm=unary_input, + lut=lut, + operator_type=params.operator_type, + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + ofm_channels=unary_input_shape[3], + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + ifm_layout=str(params.ifm.layout), + ofm_layout=str(params.ofm.layout), + ) + if len(params.ifm.shape) == 4: + op = ethosu_unary_elementwise + else: + op = relay.op.reshape(ethosu_unary_elementwise, newshape=params.ifm.shape) + return op + + +class AbsRewriter(UnaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.AbsParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AbsParams.composite_name}))( + wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeAbs: + """This is the pass that wraps the AbsRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(AbsRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -765,6 +855,7 @@ def transform_module( mod = LegalizeMin()(mod) mod = LegalizeMax()(mod) mod = LegalizeShl()(mod) + mod = LegalizeAbs()(mod) mod = LegalizeReshape()(mod) mod = LegalizeStridedSlice()(mod) mod = LegalizeNoOps()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py index 13e6fc9e7a01..8d51c8a5abea 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -21,3 +21,4 @@ from .pooling import ethosu_pooling from .binary_elementwise import ethosu_binary_elementwise from .identity import ethosu_identity +from .unary_elementwise import ethosu_unary_elementwise diff --git a/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py new file mode 100644 index 000000000000..a339561d97e3 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/unary_elementwise.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operator for unary elementwise operations for Arm(R) Ethos(TM)-U NPU""" +from typing import Optional +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import unary_elementwise_compute + + +def _extract_ethosu_unary_elementwise_params(attrs, args): + """Get the parameters necessary to construct a ethosu_unary_elementwise compute TE + from a ethosu_unary_elementwise Relay call.""" + ifm = args[0] + lut = args[1] + operator_type = attrs.operator_type + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + ofm_channels = attrs.ofm_channels + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + rounding_mode = attrs.rounding_mode + ifm_layout = attrs.ifm_layout + ofm_layout = attrs.ofm_layout + + return ( + ifm, + lut, + operator_type, + ifm_scale, + ifm_zero_point, + ofm_scale, + ofm_zero_point, + ofm_channels, + activation, + clip_min, + clip_max, + rounding_mode, + ifm_layout, + ofm_layout, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMCompute") +def create_ethosu_unary_elementwise_compute(attrs, args, out_type): + """Create an ethosu_unary_elementwise compute op.""" + params = _extract_ethosu_unary_elementwise_params(attrs, args) + op = unary_elementwise_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.unary_elementwise", "FTVMStrategy") +def unary_elementwise_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_unary_elementwise_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_unary_elementwise", + ) + return strategy + + +def ethosu_unary_elementwise( + ifm: tvm.relay.Expr, + lut: tvm.relay.Expr, + operator_type: str, + ifm_scale: float, + ifm_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + ofm_channels: int, + activation: Optional[str] = "NONE", + clip_min: Optional[int] = 0, + clip_max: Optional[int] = 0, + rounding_mode: Optional[str] = "TFL", + ifm_layout: Optional[str] = "NHWC", + ofm_layout: Optional[str] = "NHWC", +) -> tvm.relay.Call: + """This is a quantized unary elementwise operation as supported by the + NPU. It accepts either NHWC or NHCWB16 format for the input data. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + lut : tvm.relay.Expr + The look-up table values to use if activation = "LUT". + operator_type: str + The type of the unary elementwise operator. + "ABS" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + ofm_channels : int + The number of OFM channels. 
+ activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int, optional + The minimum clipping value if activation = "CLIP". + clip_max : int, optional + The maximum clipping value if activation = "CLIP". + rounding_mode : str, optional + The rounding mode to apply to the Output Feature Map tensor. + "TFL" - Tensorflow Lite rounding scheme. + "TRUNCATE" - Truncate towards zero. + "NATURAL" - Round to nearest value, with x.5 rounded up towards +infinity. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_binary_elementwise op. + """ + return _make.ethosu_unary_elementwise( + ifm, + lut, + operator_type, + ifm_scale, + ifm_zero_point, + ofm_scale, + ofm_zero_point, + ofm_channels, + activation, + clip_min, + clip_max, + rounding_mode, + ifm_layout, + ofm_layout, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index a2d1526372c9..21261521ac5f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -21,3 +21,4 @@ from .pooling import * from .binary_elementwise import * from .identity import * +from .unary_elementwise import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py new file mode 100644 index 000000000000..d45a8f4fc43d --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for unary_elementwise for the NPU""" + +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def unary_elementwise_compute( + ifm: te.Tensor, + lut: te.Tensor, + operator_type: str, + ifm_scale: float, + ifm_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + ofm_channels: int, + activation: str, + clip_min: int, + clip_max: int, + rounding_mode: str, + ifm_layout: str, + ofm_layout: str, +) -> te.Tensor: + """A compute operator representing the capabilities of unary_elementwise for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + lut : te.Tensor + The look-up table values to use if activation = "LUT". 
+ operator_type: str + The type of the unary elementwise operator. + "ABS" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + ofm_channels : int + The number of OFM channels. + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + rounding_mode : str + The rounding mode to apply to the Output Feature Map tensor. + "TFL" - Tensorflow Lite rounding scheme. + "TRUNCATE" - Truncate towards zero. + "NATURAL" - Round to nearest value, with x.5 rounded up towards +infinity. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + te.Tensor + The OFM tensor. + + """ + assert ifm.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + # Changing the ifm and ofm scale to conform with that expected by Vela API + ofm_scale = ifm_scale / ofm_scale + ifm_scale = 1.0 + + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute( + ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, (0, 0, 0, 0) + ) + + # Unary elementwise compute operation + ofm_height = dmaed_ifm.shape[1] + ofm_width = dmaed_ifm.shape[2] + + unary_elementwise_attrs = { + "op": "ethosu_unary_elementwise", + "operator_type": operator_type, + "activation": activation, + "clip_min": clip_min, + "clip_max": clip_max, + "rounding_mode": rounding_mode, + } + + operators = {"ABS": te.abs} + + unary_elementwise = te.compute( + (1, ofm_height, ofm_width, ofm_channels), + lambda nn, hh, ww, cc: operators[operator_type]( + dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype) + ), + name="ethosu_unary_elementwise", + attrs=unary_elementwise_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(unary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index b070b11c0bf5..cb46ba319edd 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -25,6 +25,7 @@ from .pooling import get_pooling_params from .binary_elementwise import get_binary_elementwise_params from .identity import get_identity_params +from .unary_elementwise import get_unary_elementwise_params from .transform import get_copy_params from .utils import get_weights_pointer, get_scale_bias_pointer @@ -60,6 +61,7 @@ def ReplaceOperators(): "ethosu_pooling": get_pooling_params, "ethosu_binary_elementwise": get_binary_elementwise_params, "ethosu_identity": get_identity_params, + "ethosu_unary_elementwise": get_unary_elementwise_params, } pointer_to_producer = {} pointer_to_consumer = {} diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py 
b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index 6201b1a38b18..f9d38df9d901 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -290,3 +290,22 @@ def __init__( self.reversed_operands = reversed_operands self.activation = activation self.rounding_mode = rounding_mode + + +class SerialUnaryElementwise(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.unary_elementwise tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + operator_type: str, + activation: SerialActivation, + rounding_mode: str, + ): + self.ifm = ifm + self.ofm = ofm + self.operator_type = operator_type + self.activation = activation + self.rounding_mode = rounding_mode diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py new file mode 100644 index 000000000000..6dc801f2b28c --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the unary_elementwise operators in TIR.""" +from tvm import tir +from .utils import get_outer_loops, get_op_attrs +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialActivation, SerialUnaryElementwise + + +def get_unary_elementwise_params(stmt, producers, consumers): + """Get the parameters necessary to construct a call_extern for a unary_elementwise. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a unary elementwise loop nest. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialUnaryElementwise + The parameters needed to construct a unary elementwise operator. + output_pointer : tvm.tir.Var + The output pointer of the unary elementwise operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the unary elementwise output pointer. 
+ + """ + attrs, body = get_op_attrs(stmt) + + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + input_pointer = None + if isinstance(inner.value, tir.expr.Select): + input_pointer = inner.value.condition.b.buffer_var + output_pointer = inner.buffer_var + # Get feature map info + serial_ifm, _ = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + SerialUnaryElementwise( + ifm=serial_ifm, + ofm=serial_ofm, + operator_type=attrs["operator_type"], + activation=serial_activation, + rounding_mode=attrs["rounding_mode"], + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index b8e79e7dae73..d276417bde3b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -309,6 +309,7 @@ def translate_ethosu_tir_call_extern(tir_call_extern): "ethosu_pooling": translate_ethosu_pooling, "ethosu_binary_elementwise": translate_ethosu_binary_elementwise, "ethosu_identity": translate_ethosu_pooling, + "ethosu_unary_elementwise": translate_ethosu_unary_elementwise, } ext_call_type = tir_call_extern.args[0].value assert ext_call_type in supported_call_extern.keys(), f"{ext_call_type} is not yet supported" @@ -770,3 +771,53 @@ def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBina npu_binary_elementwise_op.block_config = block_config return npu_binary_elementwise_op + + +def translate_ethosu_unary_elementwise( + tir_extern_call: tvm.tir.Call, +) -> vapi.NpuElementWiseOperation: + + """This function will translate a tir extern_call + as produced by Relay to TIR compilation. + Parameters + ---------- + tir_extern_call : tvm.tir.Call + This should be a tir external call that has a agreed upon ordering + for the NPU TIR Compiler. See SerialUnaryElementwise in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. 
+ + Returns + ------- + ethosu.vela.api.NpuElementWiseOperation + The vela object containing the params of ethosu_unary_elementwise + """ + serial_object = spec.create_serial_object(spec.SerialUnaryElementwise, tir_extern_call.args[1:]) + return _create_npu_op_unary_elementwise(serial_object) + + +def _create_npu_op_unary_elementwise(serial_unary_elementwise): + operator_type = serial_unary_elementwise.operator_type + if operator_type == "ABS": + op = vapi.NpuElementWiseOp.ABS + + npu_unary_elementwise_op = vapi.NpuElementWiseOperation(op) + npu_unary_elementwise_op.ifm = _create_npu_feature_map(serial_unary_elementwise.ifm) + npu_unary_elementwise_op.ofm = _create_npu_feature_map(serial_unary_elementwise.ofm) + + npu_unary_elementwise_op.activation = _create_npu_activation( + serial_unary_elementwise.activation + ) + if ( + npu_unary_elementwise_op.activation + and npu_unary_elementwise_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_unary_elementwise_op) + + npu_unary_elementwise_op.rounding_mode = _create_npu_rounding_mode( + serial_unary_elementwise.rounding_mode + ) + target_accel_type = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_unary_elementwise_op, target_accel_type) + npu_unary_elementwise_op.block_config = block_config + + return npu_unary_elementwise_op diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 370821aefa7e..589ab21b3998 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -80,14 +80,36 @@ class BinaryElementwiseArgs(Enum): of binary elementwise arguments """ - ifm = 0 - ifm2 = 1 - ifm_scale = 2 - ifm_zero_point = 3 - ifm2_scale = 4 - ifm2_zero_point = 5 - ofm_scale = 6 - ofm_zero_point = 7 + IFM = 0 + IFM2 = 1 + IFM_SCALE = 2 + IFM_ZERO_POINT = 3 + IFM2_SCALE = 4 + IFM2_ZERO_POINT = 5 + OFM_SCALE = 6 + OFM_ZERO_POINT = 7 + + +class QuantizeArgs(Enum): + """ + This is a helper enums to access the correct index of + quantize arguments + """ + + IFM = 0 + OFM_SCALE = 1 + OFM_ZERO_POINT = 2 + + +class DequantizeArgs(Enum): + """ + This is a helper enums to access the correct index of + dequantize arguments + """ + + IFM = 0 + IFM_SCALE = 1 + IFM_ZERO_POINT = 2 def is_composite_func(func: relay.Function, name: str) -> bool: diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 8b4ee21d2892..f37fcf6f97f4 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -41,6 +41,8 @@ from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs from tvm.relay.backend.contrib.ethosu.util import RequantArgs from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs + from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs + from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs from tvm.relay.backend.contrib.ethosu.util import get_dim_value except ImportError: vapi = None @@ -481,30 +483,30 @@ def __init__(self, func_body: Call, operator_type: str, has_quantization_paramet if has_quantization_parameters: self.ifm = TensorParams( - binary_op.args[BinaryElementwiseArgs.ifm.value], + binary_op.args[BinaryElementwiseArgs.IFM.value], layout, - binary_op.args[BinaryElementwiseArgs.ifm_scale.value], - binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value], + binary_op.args[BinaryElementwiseArgs.IFM_SCALE.value], + 
binary_op.args[BinaryElementwiseArgs.IFM_ZERO_POINT.value], ) self.ifm2 = TensorParams( - binary_op.args[BinaryElementwiseArgs.ifm2.value], + binary_op.args[BinaryElementwiseArgs.IFM2.value], layout, - binary_op.args[BinaryElementwiseArgs.ifm2_scale.value], - binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value], + binary_op.args[BinaryElementwiseArgs.IFM2_SCALE.value], + binary_op.args[BinaryElementwiseArgs.IFM2_ZERO_POINT.value], ) self.ofm = TensorParams( binary_op, layout, - binary_op.args[BinaryElementwiseArgs.ofm_scale.value], - binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value], + binary_op.args[BinaryElementwiseArgs.OFM_SCALE.value], + binary_op.args[BinaryElementwiseArgs.OFM_ZERO_POINT.value], ) else: self.ifm = TensorParams( - binary_op.args[BinaryElementwiseArgs.ifm.value], + binary_op.args[BinaryElementwiseArgs.IFM.value], layout, ) self.ifm2 = TensorParams( - binary_op.args[BinaryElementwiseArgs.ifm2.value], + binary_op.args[BinaryElementwiseArgs.IFM2.value], layout, ) self.ofm = TensorParams( @@ -852,6 +854,61 @@ def strided_slice_pattern(): return pattern +class AbsParams: + """ + This class will parse a call to a ethosu.unary_elementwise Abs composite function + and extract the parameter information. + """ + + composite_name = "ethos-u.abs" + + def __init__(self, func_body: Call): + quantize = func_body + abs_op = quantize.args[0] + dequantize = abs_op.args[0] + + layout = "NHWC" + + self.ifm = TensorParams( + dequantize.args[DequantizeArgs.IFM.value], + layout, + dequantize.args[DequantizeArgs.IFM_SCALE.value], + dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value], + ) + self.ofm = TensorParams( + quantize, + layout, + quantize.args[QuantizeArgs.OFM_SCALE.value], + quantize.args[QuantizeArgs.OFM_ZERO_POINT.value], + ) + + self.operator_type = "ABS" + self.activation = None + + def is_valid(self): + """Checks whether Abs has compatible attributes with HW""" + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8, np.uint8]): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_dimensions(self.ifm): + return False + if len(self.ifm.shape) == 4 and self.ifm.shape[0] != 1: + return False + if self.ifm.shape != self.ofm.shape: + return False + return True + + +def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """Create pattern for abs""" + pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) + pattern = is_op("abs")(pattern) + pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant()) + return pattern + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -915,6 +972,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal strided_slice_pattern(), lambda pat: StridedSliceParams(pat).is_valid(), ), + ( + AbsParams.composite_name, + abs_pattern(), + lambda pat: AbsParams(pat).is_valid(), + ), ] diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc index a9376791595b..48b085a2b6f2 100644 --- a/src/relay/op/contrib/ethosu/binary_elementwise.cc +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -238,8 +238,8 @@ bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const } // Assign ofm type - auto ofm_shape = EthosuInferBinaryElementwiseOutputShape(ifm->shape, param->ifm_layout, - param->ofm_layout, param->ifm_channels); + auto ofm_shape = 
EthosuInferElementwiseOutputShape(ifm->shape, param->ifm_layout, + param->ofm_layout, param->ifm_channels); reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype)); return true; } diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index bdaa9da52618..817575cc8d0d 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -32,9 +32,8 @@ namespace op { namespace contrib { namespace ethosu { -Array EthosuInferBinaryElementwiseOutputShape(Array ifm_shape, - String ifm_layout, String ofm_layout, - IndexExpr ofm_channels) { +Array EthosuInferElementwiseOutputShape(Array ifm_shape, String ifm_layout, + String ofm_layout, IndexExpr ofm_channels) { // In the case of NHCWB16, convert the ifm shape to NHW (C not required for this function) if (ifm_layout == "NHCWB16") { ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]}; diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index 574fb91181ef..cc489de6a49a 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -40,9 +40,8 @@ namespace ethosu { * \param ofm_channels The number of Output Feature Map channels. * \return The shape of the output tensor. */ -Array EthosuInferBinaryElementwiseOutputShape(Array ifm_shape, - String ifm_layout, String ofm_layout, - IndexExpr ofm_channels); +Array EthosuInferElementwiseOutputShape(Array ifm_shape, String ifm_layout, + String ofm_layout, IndexExpr ofm_channels); /*! \brief Infer the output tensor shape for convolution and pooling operators. * \param ifm_shape The shape of Input Feature Map. diff --git a/src/relay/op/contrib/ethosu/unary_elementwise.cc b/src/relay/op/contrib/ethosu/unary_elementwise.cc new file mode 100644 index 000000000000..60f1eefaa6b2 --- /dev/null +++ b/src/relay/op/contrib/ethosu/unary_elementwise.cc @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/unary_elementwise.cc + * \brief Property def of the Arm(R) Ethos(TM)-U unary elementwise ops. + */ +#include + +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! 
\brief Attributes used by the NPU unary elementwise operator */ +struct EthosuUnaryElementwiseAttrs : public tvm::AttrsNode { + String operator_type; + double ifm_scale; + int ifm_zero_point; + double ofm_scale; + int ofm_zero_point; + IndexExpr ofm_channels; + String activation; + int clip_min; + int clip_max; + String rounding_mode; + String ifm_layout; + String ofm_layout; + + TVM_DECLARE_ATTRS(EthosuUnaryElementwiseAttrs, "relay.attrs.EthosuUnaryElementwiseAttrs") { + TVM_ATTR_FIELD(operator_type) + .describe( + "The type of the unary elementwise operator." + "'ABS'"); + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_channels).describe("The number of OFM channels."); + TVM_ATTR_FIELD(activation) + .describe( + "The activation function to use. " + "'NONE' - no activation function. " + "'CLIP' - clip the output between clip_min and clip_max. " + "'TANH' - tanh activation function. " + "'SIGMOID' - sigmoid activation function. " + "'LUT' - use a look-up table to perform the activation function.") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(rounding_mode) + .describe( + "The rounding mode to apply to the Output Feature Map tensor. " + "'TFL' - Tensorflow Lite rounding scheme. " + "'TRUNCATE' - Truncate towards zero." + "'NATURAL' - Round to nearest value, with x.5 rounded up towards +infinity.") + .set_default("TFL"); + TVM_ATTR_FIELD(ifm_layout) + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_layout) + .describe("The layout of the Output Feature Map tensor. 
Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuUnaryElementwiseAttrs); + +bool EthosuUnaryElementwiseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const int ifm_index = 0; + const int result_index = 2; + ICHECK_EQ(types.size(), result_index + 1); + + const auto* ifm = types[ifm_index].as(); + if (ifm == nullptr) return false; + + const auto* param = attrs.as(); + CHECK(param != nullptr) << "EthosuUnaryElementwiseAttrs cannot be nullptr."; + + String operator_type = param->operator_type; + if (operator_type != "ABS") { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_unary_elementwise 'ABS' for operator_type but was" + << operator_type); + return false; + } + + auto ifm_dtype = ifm->dtype; + if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_unary_elementwise input data type " + << "of type(uint8) or type(int8) but was " << ifm_dtype); + return false; + } + + // Assign ofm type + auto ofm_shape = EthosuInferElementwiseOutputShape(ifm->shape, param->ifm_layout, + param->ofm_layout, param->ofm_channels); + reporter->Assign(types[result_index], TensorType(ofm_shape, ifm_dtype)); + return true; +} + +Expr MakeEthosuUnaryElementwise(Expr ifm, Expr lut, String operator_type, double ifm_scale, + int ifm_zero_point, double ofm_scale, int ofm_zero_point, + IndexExpr ofm_channels, String activation, int clip_min, + int clip_max, String rounding_mode, String ifm_layout, + String ofm_layout) { + auto attrs = make_object(); + + attrs->operator_type = std::move(operator_type); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->ofm_channels = std::move(ofm_channels); + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->rounding_mode = std::move(rounding_mode); + attrs->ifm_layout = std::move(ifm_layout); + attrs->ofm_layout = std::move(ofm_layout); + + static const Op& op = Op::Get("contrib.ethosu.unary_elementwise"); + return Call(op, {ifm, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_unary_elementwise") + .set_body_typed(MakeEthosuUnaryElementwise); + +RELAY_REGISTER_OP("contrib.ethosu.unary_elementwise") + .describe(R"code(Quantized unary elementwise operator for Arm(R) Ethos(TM)-U NPUs. + +This Relay operator corresponds to the hardware-implemented quantized +unary elementwise operation found on NPUs. It accepts either NHWC +or NHCWB16 format for the inputs data (input feature maps, or IFMs). 
+ +Reference: https://developer.arm.com/documentation/102420/0200/ + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ofm**: (1, ofm_height, ofm_width, ofm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuUnaryElementwise", EthosuUnaryElementwiseRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index ecd404aa2d08..1c0b78cebf92 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -623,3 +623,31 @@ def make_ethosu_identity( activation=activation, ) return identity + + +def make_ethosu_unary_elementwise( + ifm, + ofm_channels, + operator_type, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + rounding_mode="TFL", +): + ethosu_unary_elementwise = ethosu_ops.ethosu_unary_elementwise( + ifm=ifm, + lut=relay.const([], dtype="int8"), + operator_type=operator_type, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ofm_channels=ofm_channels, + activation=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + rounding_mode=rounding_mode, + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return ethosu_unary_elementwise diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 93af66da8194..5f4f4b17c755 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -765,5 +765,81 @@ def test_relay_strided_slice_codegen(ifm_shape, begin, end, accel_type): infra.verify_source(compiled_model, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("operator_type", ["ABS"]) +@pytest.mark.parametrize( + "ifm_shape", + [[1, 5, 12, 4], [1, 1, 2], [4, 3, 2], [10, 20], [345]], +) +def test_ethosu_unary_elementwise( + accel_type, + operator_type, + ifm_shape, +): + dtype = "int8" + + def get_tflite_graph(): + class Model(tf.Module): + @tf.function + def abs_func(self, x): + if operator_type == "ABS": + op = tf.math.abs(x) + return op + + model = Model() + + concrete_func = model.abs_func.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32) * 2 - 1] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = get_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate 
reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 8c3e4e31c1ca..12bdddc978e3 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -890,5 +890,107 @@ def test_relay_strided_slice_legalize(ifm_shape, begin, end): assert list(identity.checked_type.shape) == slice_shape +@pytest.mark.parametrize("operator_type", ["ABS"]) +@pytest.mark.parametrize( + "ifm_shape", + [[1, 2, 3, 4], [1, 7, 3], [8, 3, 1], [11, 22], [300]], +) +def test_tflite_unary_elemwise_legalize( + operator_type, + ifm_shape, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def abs_func(self, x): + if operator_type == "ABS": + op = tf.math.abs(x) + return op + + model = Model() + + # Save the model + concrete_func = model.abs_func.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + out_shape = ifm_shape + func_body = ext_func.body + + # If we legalized the unary elementwise op into 4D + if func_body.op.name == "reshape": + reshape = func_body + unary = func_body.args[0] + reshape2 = unary.args[0] + + # Check the input to the reshape + reshape2_in_shape = [i for i in reshape2.args[0].checked_type.shape] + assert reshape2_in_shape == ifm_shape + + # Check that the unary elementwise operator is 4D after reshape + assert len(unary.checked_type.shape) == 4 + assert unary.args[0].checked_type.dtype == dtype + + # Check that the output of the graph has the same shape as input + reshape_out_shape = [i for i in reshape.checked_type.shape] + assert reshape_out_shape == ifm_shape + assert unary.attrs.operator_type == operator_type + + else: + unary = func_body + + # Check the IFM + assert list(unary.args[0].checked_type.shape) == ifm_shape + assert unary.args[0].checked_type.dtype == dtype + + # Check the OFM + assert list(unary.checked_type.shape) == out_shape + assert unary.checked_type.dtype == dtype + + # operator type check + assert unary.attrs.operator_type == operator_type + + if operator_type == "ABS": + rewriter = legalize.AbsRewriter() + pattern_table = [ + ( + ethosu.AbsParams.composite_name, + ethosu.abs_pattern(), + lambda pat: 
ethosu.AbsParams(pat).is_valid(), + ), + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + mod = partition_ethosu_by_table(mod, pattern_table) + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethos_u_main_0"] + ) + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py new file mode 100644 index 000000000000..eff81c4e6cbd --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_unary_elementwise.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +import tvm.script +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir import spec +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_unary_elementwise + + +def _get_unary_elementwise_args(call, include_buffers=False, remove_constants=False): + args = call.args + unary_elementwise_args = [] + + for i, arg in enumerate(args): + if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + unary_elementwise_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + unary_elementwise_args.append(arg.index) + else: + unary_elementwise_args.append(arg) + + return unary_elementwise_args + + +@pytest.mark.parametrize( + "ifm_shape, ifm_channels, ifm_layout, ofm_layout, rounding_mode", + [ + ((1, 5, 9, 3), 3, "NHWC", "NHWC", "TFL"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16", "NATURAL"), + ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC", "TRUNCATE"), + ((1, 8, 9, 40), 40, "NHWC", "NHCWB16", "TFL"), + ], +) +@pytest.mark.parametrize("operator_type", ["ABS"]) +@pytest.mark.parametrize("activation", ["NONE"]) +def test_unary_elementwise_single( + ifm_shape, + ifm_channels, + ifm_layout, + ofm_layout, + rounding_mode, + operator_type, + activation, +): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + + unary_elementwise = make_ethosu_unary_elementwise( + ifm, ifm_channels, operator_type, activation, ifm_layout, ofm_layout, rounding_mode + ) + func = relay.Function(relay.analysis.free_vars(unary_elementwise), unary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(_get_unary_elementwise_args(stmt, 
remove_constants=True)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1 + ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1 + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[2] + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[3] + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ifm_channels if ofm_width > 1 else 1 + ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1) + + serial_unary_elementwise = spec.SerialUnaryElementwise( + ifm=spec.SerialFeatureMap( + data_type="int8", + height=ifm_shape[1], + width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels=ifm_channels, + tile_height_0=ifm_shape[1], + tile_height_1=0, + tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm_stride_h, + stride_w=ifm_stride_w, + stride_c=ifm_stride_c, + ), + ofm=spec.SerialFeatureMap( + data_type="int8", + height=ofm_height, + width=ofm_width, + channels=ifm_channels, + tile_height_0=ofm_height, + tile_height_1=0, + tile_width_0=ofm_width, + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ofm_layout, + stride_h=ofm_stride_h, + stride_w=ofm_stride_w, + stride_c=ofm_stride_c, + ), + operator_type=operator_type, + activation=spec.SerialActivation( + op=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + ), + rounding_mode=rounding_mode, + ) + + assert data[0] == ["ethosu_unary_elementwise"] + list(serial_unary_elementwise) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py index 8d10d89f6afa..778e4efc4b24 100644 --- a/tests/python/contrib/test_ethosu/test_type_inference.py +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -25,6 +25,7 @@ from .infra import make_ethosu_pooling from .infra import make_ethosu_binary_elementwise from .infra import make_ethosu_identity +from .infra import make_ethosu_unary_elementwise @pytest.mark.parametrize( @@ -364,7 +365,7 @@ def test_ethosu_identity_invalid_shape(): run_opt_pass(func, relay.transform.InferType()) -def test_ethosu_invalid_dtype(): +def test_ethosu_identity_invalid_dtype(): invalid_dtype = "int32" ifm = relay.var("ifm", shape=[6000], dtype=invalid_dtype) @@ -374,5 +375,59 @@ def test_ethosu_invalid_dtype(): run_opt_pass(func, relay.transform.InferType()) +@pytest.mark.parametrize( + "ifm_shape, ifm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] +) +def test_ethosu_unary_elementwise_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + operator_type = "ABS" + ofm_channels = 33 + unary_elementwise = make_ethosu_unary_elementwise( + ifm, + ofm_channels, + operator_type, + ifm_layout=ifm_layout, + 
ofm_layout=ofm_layout, + ) + f = relay.Function([ifm], unary_elementwise) + f = run_opt_pass(f, relay.transform.InferType()) + assert tuple(f.body.checked_type.shape) == ofm_shape + + +def test_ethosu_unary_elementwise_invalid_operator_type(): + ifm = relay.var("ifm", shape=(1, 3, 7, 12), dtype="int8") + invalid_op_type = "ABBBS" + unary_elementwise = make_ethosu_unary_elementwise( + ifm, + 12, + invalid_op_type, + ) + func = relay.Function([ifm], unary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +def test_ethosu_unary_elementwise_invalid_dtype(): + invalid_dtype = "int32" + ifm = relay.var("ifm", shape=(1, 5, 15, 25), dtype=invalid_dtype) + + unary_elementwise = make_ethosu_unary_elementwise( + ifm, + 25, + "ABS", + ) + func = relay.Function([ifm], unary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + if __name__ == "__main__": pytest.main([__file__])
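
A minimal usage sketch of the new Relay op (not part of the diff above), mirroring make_ethosu_unary_elementwise in tests/python/contrib/test_ethosu/infra.py and the type inference test in test_type_inference.py. It assumes a TVM build that includes this patch and, as the tests require, an installed ethosu.vela package; the concrete shape and quantization values are arbitrary examples:

from tvm import relay
from tvm.relay.testing import run_opt_pass
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

# Build the new unary elementwise op on an NHWC int8 input ("ABS" is the only
# operator_type this patch supports) and run Relay type inference over it.
ifm = relay.var("ifm", shape=(1, 5, 9, 3), dtype="int8")
abs_op = ethosu_ops.ethosu_unary_elementwise(
    ifm=ifm,
    lut=relay.const([], dtype="int8"),  # empty placeholder; LUT activations are not yet supported
    operator_type="ABS",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    ofm_channels=3,
)
func = run_opt_pass(relay.Function([ifm], abs_op), relay.transform.InferType())
# With the default NHWC layouts the inferred OFM shape matches the IFM shape.
assert tuple(func.body.checked_type.shape) == (1, 5, 9, 3)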