diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index c5213fe07471..f0280a90c604 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -75,6 +75,18 @@ struct QuantizeAttrs : public tvm::AttrsNode { } }; +struct SimulatedQuantizeAttrs : public tvm::AttrsNode { + int axis; + + TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { + TVM_ATTR_FIELD(axis) + .describe( + "The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + } +}; + /*! \brief Attribute for dequantize operator */ struct DequantizeAttrs : public tvm::AttrsNode { int axis; diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py index 6d66e12eeafc..848409360a9d 100644 --- a/python/tvm/relay/qnn/op/__init__.py +++ b/python/tvm/relay/qnn/op/__init__.py @@ -19,4 +19,4 @@ from __future__ import absolute_import as _abs from .qnn import * from .op import register_qnn_legalize -from . import legalizations, layout_conversions +from . import _qnn, legalizations, layout_conversions diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py new file mode 100644 index 000000000000..a059c293a0f8 --- /dev/null +++ b/python/tvm/relay/qnn/op/_qnn.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, len-as-condition +"""QNN operator feature registration""" + +from tvm import topi + +from ...op.op import register_compute +from ...op.op import register_injective_schedule +from ...op.op import register_pattern, OpPattern + + +@register_compute("qnn.simulated_quantize") +def simulated_quantize_compute(attrs, inputs, output_type): + assert len(inputs) == 4 + return [ + topi.nn.simulated_quantize( + inputs[0], inputs[1], inputs[2], inputs[3], axis=attrs.get_int("axis") + ) + ] + + +register_injective_schedule("qnn.simulated_quantize") +register_pattern("qnn.simulated_quantize", OpPattern.ELEMWISE) + + +@register_compute("qnn.simulated_dequantize") +def simulated_dequantize_compute(attrs, inputs, output_type): + assert len(inputs) == 4 + return [ + topi.nn.simulated_dequantize( + inputs[0], inputs[1], inputs[2], inputs[3], axis=attrs.get_int("axis") + ) + ] + + +register_injective_schedule("qnn.simulated_dequantize") +register_pattern("qnn.simulated_dequantize", OpPattern.ELEMWISE) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index a5892f331f06..f02f8227e14a 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,8 +18,10 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs +from tvm import relay from tvm.relay.expr import Tuple, TupleWrapper from tvm.relay.op.nn.utils import get_pad_tuple2d +from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE from . import _make from ... 
import op as reg from ...op import OpPattern @@ -118,6 +120,40 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"): return _make.quantize(data, output_scale, output_zero_point, axis, out_dtype) +def simulated_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"): + r"""Simulated Quantize op + Mimics the quantize op but has more flexibility in valid inputs and always + outputs the same type as the input. This can be useful for + calibrating or training a quantized network. + + Parameters + ---------- + data : tvm.relay.Expr + The input tensor to be quantized. Can be of type float32. + output_zero_point : tvm.relay.Expr + The output zero_point. + output_scale : tvm.relay.Expr + The output scale. + axis : int + The channel axis for quantization. Default value is -1 which corresponds to the last axis. + out_dtype : string or tvm.relay.Expr + A string or tensor indicating which datatype to quantize to. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + # Convert string dtype to a constant if needed. + if isinstance(out_dtype, str): + type_code = SQNN_DTYPE_TO_CODE[out_dtype] + out_dtype = relay.const(type_code, dtype="int32") + # Wrap reshapes around qnn parameter tensors to guarantee shape compatibility. + output_scale = relay.op.reshape(output_scale, [-1]) + output_zero_point = relay.op.reshape(output_zero_point, [-1]) + return _make.simulated_quantize(data, out_dtype, output_scale, output_zero_point, axis) + + def dequantize(data, input_scale, input_zero_point, axis=-1): r"""Dequantize op This operator takes quantized int8 and unit8 as input and produces @@ -127,7 +163,7 @@ def dequantize(data, input_scale, input_zero_point, axis=-1): Parameters ---------- data : tvm.relay.Expr - The input tensor to be dequantized. Can be of type [int8, uint8]. + The input tensor to be dequantized. Can be of type [int8, uint8, int32]. input_zero_point : tvm.relay.Expr The input zero_point. 
input_scale : tvm.relay.Expr @@ -143,6 +179,40 @@ def dequantize(data, input_scale, input_zero_point, axis=-1): return _make.dequantize(data, input_scale, input_zero_point, axis) +def simulated_dequantize(data, input_scale, input_zero_point, axis=-1, in_dtype="int8"): + r"""Simulated Dequantize op + Mimics the dequantize op but has more flexibility in valid inputs and always + outputs the same type as the input. This can be useful for calibrating or + training a quantized network. + + Parameters + ---------- + data : tvm.relay.Expr + The input tensor to be dequantized. + input_zero_point : tvm.relay.Expr + The input zero_point. + input_scale : tvm.relay.Expr + The input scale. + axis : int + The channel axis for quantization. Default value is -1 which corresponds to the last axis. + in_dtype : string or tvm.relay.Expr + A string or tensor indicating which datatype to dequantize from. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + # Convert string dtype to a constant if needed. + if isinstance(in_dtype, str): + type_code = SQNN_DTYPE_TO_CODE[in_dtype] + in_dtype = relay.const(type_code, dtype="int32") + # Wrap reshapes around qnn parameter tensors to guarantee shape compatibility. + input_scale = relay.op.reshape(input_scale, [-1]) + input_zero_point = relay.op.reshape(input_zero_point, [-1]) + return _make.simulated_dequantize(data, in_dtype, input_scale, input_zero_point, axis) + + def concatenate(data, input_scales, input_zero_points, output_scale, output_zero_point, axis): """Concatenate the quantized input tensors along the given axis. 
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py index 2ebbd1d67bd1..94a5b30c9b76 100644 --- a/python/tvm/topi/nn/__init__.py +++ b/python/tvm/topi/nn/__init__.py @@ -36,6 +36,7 @@ from .conv2d_transpose import * from .conv1d_transpose import * from .bnn import * +from .qnn import * from .upsampling import * from .local_response_norm import * from .bitserial_conv2d import * diff --git a/python/tvm/topi/nn/qnn.py b/python/tvm/topi/nn/qnn.py new file mode 100644 index 000000000000..caed28580037 --- /dev/null +++ b/python/tvm/topi/nn/qnn.py @@ -0,0 +1,190 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Quantized Neural Network (QNN) Operators""" +import tvm +from tvm import te, tir, topi + +SQNN_DISABLE = 0 +SQNN_INT8 = 1 +SQNN_UINT8 = 2 +SQNN_INT32 = 3 + +SQNN_DTYPE_TO_CODE = { + "disable": SQNN_DISABLE, + "int8": SQNN_INT8, + "uint8": SQNN_UINT8, + "int32": SQNN_INT32, +} + +SQNN_CODE_TO_DTYPE = {v: k for k, v in SQNN_DTYPE_TO_CODE.items()} + + +@tvm.te.tag_scope(tag=topi.tag.ELEMWISE) +def simulated_quantize(data, out_dtype, output_scale=None, output_zero_point=None, axis=-1): + """Simulated QNN quantize operator that mimics QNN outputs without changing datatype. 
+ The benefit of this operator over true QNN quantize is that this operator allows dynamic + datatype selection and can operate on both per-channel and scalar scales and zero points while + QNN quantize requires both of these to be fixed at compile time. + + Parameters + ---------- + data: tvm.te.Tensor + An N-D input tensor to the operator. + + out_dtype: tvm.te.Tensor + A scalar variable that indicates which datatype to simulate quantization with. Use + SQNN_DTYPE_TO_CODE to convert a dtype string into the corresponding variable + value. + + output_scale: tvm.te.Tensor, optional + A scalar tensor representing the scale to use when quantizing to integer datatypes. + When it contains more than a single value, N must match the number of channels in data. + + output_zero_point: tvm.te.Tensor, optional + A 1-D tensor representing the zero point to use when quantizing to integer datatypes. + When it contains more than a single value, N must match the number of channels in data. + + axis: int, optional + The channel axis for quantization. Default value is -1 which corresponds to the last axis. + + """ + # When disabled, just pass through the input values. + def _compute_pass_through(value, *indices): + return value[indices] + + # Simulate quantization for arbitrary integer datatypes. The computation for all datatypes is: + # Q_output = clip((round(input_tensor/output_scale) + output_zero_point), + # out_dtype::min, + # out_dtype::max) + def _compute_intn(dtype, value, *indices): + assert output_scale is not None and output_zero_point is not None + const_min = tvm.tir.min_value(dtype) + const_max = tvm.tir.max_value(dtype) + # Use indexmod to handle both scalar and per-channel QNN parameters. 
+ scale_idx = tir.indexmod(indices[axis], topi.shape(output_scale)[0]) + zp_idx = tir.indexmod(indices[axis], topi.shape(output_zero_point)[0]) + return te.max( + te.min( + te.round(value[indices] / output_scale[scale_idx]) + output_zero_point[zp_idx], + const_max, + ), + const_min, + ) + + # Use an if chain to dynamically return the proper quantization based on the input datatype. + # This allows the op to compile once but apply different quantization approaches + # using a variable datatype input. + def _dispatch_sim_quantize(value): + pass_through_value = te.compute( + data.shape, lambda *indices: _compute_pass_through(value, *indices) + ) + int8_value = te.compute( + data.shape, + lambda *indices: tir.if_then_else( + out_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]), + _compute_intn("int8", value, *indices), + pass_through_value[indices], + ), + ) + uint8_value = te.compute( + data.shape, + lambda *indices: tir.if_then_else( + out_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]), + _compute_intn("uint8", value, *indices), + int8_value[indices], + ), + ) + int32_value = te.compute( + data.shape, + lambda *indices: tir.if_then_else( + out_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]), + _compute_intn("int32", value, *indices), + uint8_value[indices], + ), + ) + + return int32_value + + return te.compute(data.shape, lambda *indices: _dispatch_sim_quantize(data)[indices]) + + +@tvm.te.tag_scope(tag=topi.tag.ELEMWISE) +def simulated_dequantize(data, in_dtype, input_scale=None, input_zero_point=None, axis=-1): + """Simulated QNN dequantize operator that mimics QNN outputs without changing datatype. + The benefit of this operator over true QNN dequantize is that this operator allows dynamic + datatype selection and can operate on both per-channel and scalar scales and zero points while + QNN dequantize requires both of these to be fixed at compile time. + + Parameters + ---------- + data: tvm.te.Tensor + An N-D input tensor to the operator. 
+ + in_dtype: tvm.te.Tensor + A scalar variable that indicates which datatype to simulate dequantization with. Use + SQNN_DTYPE_TO_CODE to convert a dtype string into the corresponding variable + value. + + input_scale: tvm.te.Tensor, optional + A scalar tensor representing the scale to use when dequantizing from integer datatypes. + When it contains more than a single value, N must match the number of channels in data. + + input_zero_point: tvm.te.Tensor, optional + A 1-D tensor representing the zero point to use when dequantizing from integer datatypes. + When it contains more than a single value, N must match the number of channels in data. + + axis: int, optional + The channel axis for quantization. Default value is -1 which corresponds to the last axis. + + """ + # When disabled simply return the input tensor. + def _compute_pass_through(value, *indices): + return value[indices] + + # Simulate dequantization for arbitrary integer datatypes. The computation for all datatypes is: + # DQ_output = (input - zero_point) * scale + def _compute_intn(value, *indices): + assert input_scale is not None and input_zero_point is not None + # Use indexmod to handle both scalar and per-channel QNN parameters. + scale_idx = tir.indexmod(indices[axis], topi.shape(input_scale)[0]) + zp_idx = tir.indexmod(indices[axis], topi.shape(input_zero_point)[0]) + return (value[indices] - input_zero_point[zp_idx]) * input_scale[scale_idx] + + # Use an if chain to dynamically return the proper dequantization based on the input datatype. + # This allows the op to compile once but apply different quantization approaches + # using a variable datatype input. 
+ def _dispatch_sim_dequantize(value): + pass_through_value = te.compute( + data.shape, lambda *indices: _compute_pass_through(value, *indices) + ) + intn_condition = tvm.te.any( + in_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]), + in_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]), + in_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]), + ) + intn_value = te.compute( + data.shape, + lambda *indices: tir.if_then_else( + intn_condition, + _compute_intn(value, *indices), + pass_through_value[indices], + ), + ) + + return intn_value + + return te.compute(data.shape, lambda *indices: _dispatch_sim_dequantize(data)[indices]) diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 724441e0c523..b0fe9356a758 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -53,7 +53,7 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1 : axis; + axis = (axis < 0) ? data->shape.size() + axis : axis; ICHECK_LT(axis, static_cast(data->shape.size())) << "axis " << dequantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; @@ -81,7 +81,7 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point, const Array& types, const DequantizeAttrs* attrs) { - const auto axis = attrs->axis; + auto axis = attrs->axis; ICHECK_EQ(types.size(), 4); auto in_type = types[0]; @@ -92,6 +92,11 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, size_t n_dim = input_shape.size(); + // Wrap axis from negative to positive if needed. 
+ if (axis < 0) { + axis = static_cast(n_dim) + axis; + } + // Expand scale and zero point if the input tensor is channel quantized auto expanded_input_scale = input_scale; if (!IsConstScalar(input_scale) && !IsScalarType(types[1])) { diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 9829834f43a3..751abfc5ca81 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -19,8 +19,8 @@ /*! * \file src/relay/qnn/op/quantize.cc - * \brief QNN dequantize operator. Dequantize operator converts from quantized - * domain to unquantized domain. + * \brief QNN quantize operator. Quantize operator converts from unquantized + * domain to quantized domain. */ #include @@ -51,7 +51,7 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1 : axis; + axis = (axis < 0) ? data->shape.size() + axis : axis; ICHECK_LT(axis, static_cast(data->shape.size())) << "axis " << quantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; @@ -93,10 +93,15 @@ Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, Array input_shape = in_tensor_type->shape; const auto out_dtype = attrs->out_dtype; - const auto axis = attrs->axis; + auto axis = attrs->axis; size_t n_dim = input_shape.size(); + // Wrap axis from negative to positive if needed. 
+ if (axis < 0) { + axis = static_cast(n_dim) + axis; + } + auto expanded_output_scale = output_scale; if (!IsConstScalar(output_scale) && !IsScalarType(types[1])) { expanded_output_scale = ExpandBiasToMatchAxis(output_scale, n_dim, {axis}); diff --git a/src/relay/qnn/op/simulated_dequantize.cc b/src/relay/qnn/op/simulated_dequantize.cc new file mode 100644 index 000000000000..e1fc47d700c9 --- /dev/null +++ b/src/relay/qnn/op/simulated_dequantize.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/simulated_dequantize.cc + * \brief QNN simulated dequantize operator. Mimics the behavior + * of QNN dequantize in floating point with added flexibility. 
+ */ + +#include +#include +#include + +#include "../../transforms/pattern_utils.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace qnn { + +bool SimulatedDequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types = [data_type, datatype_type, scale_type, zp_type, ret_type] + ICHECK_EQ(types.size(), 5); + const auto* data = types[0].as(); + const auto* dtype = types[1].as(); + + if ((data == nullptr) || (dtype == nullptr)) { + return false; + } + + // assign output type + reporter->Assign(types[4], TensorType(data->shape, data->dtype)); + return true; +} + +Expr MakeSimulatedDequantize(Expr data, Expr in_dtype, Expr input_scale, Expr input_zero_point, + int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("qnn.simulated_dequantize"); + return Call(op, {data, in_dtype, input_scale, input_zero_point}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.simulated_dequantize") + .describe(R"code(Simulates the functionality of qnn.dequantize but allows more flexible + dynamic input type conversion and always operates on float values. 
+)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "Tensor", "The tensor to dequantize.") + .add_argument("in_dtype", "Tensor", + "A code corresponding to the type of quantization to convert from.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .set_support_level(11) + .add_type_rel("QNNSimulatedDequantize", SimulatedDequantizeRel); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.simulated_dequantize") + .set_body_typed(MakeSimulatedDequantize); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/src/relay/qnn/op/simulated_quantize.cc b/src/relay/qnn/op/simulated_quantize.cc new file mode 100644 index 000000000000..089762a6ade0 --- /dev/null +++ b/src/relay/qnn/op/simulated_quantize.cc @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/simulated_quantize.cc + * \brief QNN simulated quantize operator. Mimics the behavior + * of QNN quantize in floating point with added flexibility. 
+ */ + +#include +#include +#include + +#include "../../transforms/pattern_utils.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace qnn { + +TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); + +bool SimulatedQuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types = [data_type, datatype_type, scale_type, zp_type, ret_type] + ICHECK_EQ(types.size(), 5); + const auto* data = types[0].as(); + const auto* dtype = types[1].as(); + + if ((data == nullptr) || (dtype == nullptr)) { + return false; + } + + // assign output type + reporter->Assign(types[4], TensorType(data->shape, data->dtype)); + return true; +} + +Expr MakeSimulatedQuantize(Expr data, Expr out_dtype, Expr output_scale, Expr output_zero_point, + int axis) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("qnn.simulated_quantize"); + return Call(op, {data, out_dtype, output_scale, output_zero_point}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.simulated_quantize") + .describe(R"code(Simulates the functionality of qnn.quantize but allows more flexible + dynamic input type conversion and always outputs float values. 
+)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "Tensor", "The tensor to quantize.") + .add_argument("out_dtype", "Tensor", + "A code corresponding to the type of quantization to apply.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("QNNSimulatedQuantize", SimulatedQuantizeRel); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.simulated_quantize").set_body_typed(MakeSimulatedQuantize); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index e7fb161a13cb..1833458fdb75 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -98,7 +98,7 @@ def test_channelwise_axis_1(): } dequantize_test_driver( - in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=1 + in_dtype="uint8", quant_args=quant_args, in_data=data, verify_output_data=output, axis=-1 ) diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 2ef298679904..b300c5612174 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -127,7 +127,7 @@ def test_channelwise_axis_1(): quantize_test_driver( in_dtype="float32", quant_args=quant_args, - axis=1, + axis=-1, out_dtype="uint8", in_data=data, verify_output_data=output, diff --git a/tests/python/relay/test_op_qnn_simulated_dequantize.py b/tests/python/relay/test_op_qnn_simulated_dequantize.py new file mode 100644 index 000000000000..0cc04e4998eb --- /dev/null +++ b/tests/python/relay/test_op_qnn_simulated_dequantize.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import te +import numpy as np +from tvm import relay +from tvm.contrib import graph_runtime +from tvm.runtime.vm import VirtualMachine +from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE + + +def dequantize_test_driver(in_dtype, quant_args, axis, in_data): + shape = in_data.shape + input_data = relay.var("input_data", shape=shape, dtype=in_dtype) + input_zero_point = relay.const(quant_args["in_zero_point"]) + input_scale = relay.const(quant_args["in_scale"]) + dequantized_output = relay.qnn.op.dequantize( + input_data, + input_scale=input_scale, + input_zero_point=input_zero_point, + axis=axis, + ) + mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output) + mod = tvm.IRModule.from_expr(mod) + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(mod, "llvm", params=None) + rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + rt_mod.set_input(input_data=in_data) + rt_mod.set_input(**params) + rt_mod.run() + res = rt_mod.get_output(0).asnumpy() + return res + + +def build_simulated_dequantize(input_data, scale, zp, dtype, axis=-1): + sim_q = relay.qnn.op.simulated_dequantize( + input_data, + scale, + zp, + axis=axis, + in_dtype=dtype, + ) + mod = tvm.IRModule.from_expr(sim_q) + with 
tvm.transform.PassContext(opt_level=3): + vm_exec = relay.vm.compile(mod, "llvm", params=None) + vm = VirtualMachine(vm_exec, tvm.cpu(0)) + return vm + + +def verify_simulated_dequantize_simple(dtype): + data = np.random.uniform(low=-128, high=127, size=[2, 5]).astype(dtype) + data_fp = data.astype("float32") + scale_np = np.float32(0.5) + zp_np = np.int32(127) + dtype_np = np.int32(SQNN_DTYPE_TO_CODE[dtype]) + quant_args = {"in_zero_point": zp_np, "in_scale": scale_np} + dq_out = dequantize_test_driver( + in_dtype=dtype, + quant_args=quant_args, + axis=-1, + in_data=data, + ) + input_data = relay.var("input_data", shape=data.shape, dtype="float32") + scale = relay.var("scale", shape=[]) + zp = relay.var("zp", shape=[]) + dtype = relay.var("dtype", shape=[]) + vm = build_simulated_dequantize(input_data, scale, zp, dtype) + sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_dq_out.asnumpy(), dq_out) + + +def test_simulated_dequantize(): + verify_simulated_dequantize_simple("uint8") + verify_simulated_dequantize_simple("int8") + verify_simulated_dequantize_simple("int32") + + +def test_dynamic_channels(): + # Compile simulated quantize once but support either per-channel or scalar params. + data = np.random.uniform(low=-64, high=64, size=[2, 5]).astype("int8") + data_fp = data.astype("float32") + # Test scalar qnn params. + scale_np = np.asarray([0.5]).astype("float32") + zp_np = np.asarray([0]).astype("int32") + dtype_np = np.int32(SQNN_DTYPE_TO_CODE["int8"]) + quant_args = {"in_zero_point": zp_np[0], "in_scale": scale_np[0]} + dq_out = dequantize_test_driver( + in_dtype="int8", + quant_args=quant_args, + axis=0, + in_data=data, + ) + # Create variables with undefined shape and run with scalar inputs. 
+ input_data = relay.var("input_data", shape=data.shape, dtype="float32") + scale = relay.var("scale", shape=[relay.Any()], dtype="float32") + zp = relay.var("zp", shape=[relay.Any()], dtype="int32") + dtype = relay.var("dtype", shape=[]) + vm = build_simulated_dequantize(input_data, scale, zp, dtype, axis=0) + sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_dq_out.asnumpy(), dq_out) + + # Now get the perchannel quantize output and compare without recompiling. + scale_np = np.array([0.5, 0.25]).astype("float32") + zp_np = np.array([127, 123]).astype("int32") + + # Get the reference quantize output. + quant_args = {"in_zero_point": zp_np, "in_scale": scale_np} + dq_out = dequantize_test_driver( + in_dtype="int8", + quant_args=quant_args, + axis=0, + in_data=data, + ) + # Run the simulated quantize without recompiling and confirm results match. + sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_dq_out.asnumpy(), dq_out) + + +def test_dynamic_dtype(): + # Compile simulated quantize once but support any type of quantization. + data = np.random.uniform(low=0, high=255, size=[2, 5]).astype("uint8") + data_fp = data.astype("float32") + # Test scalar uint8 to fp32. + scale_np = np.asarray([0.5]).astype("float32") + zp_np = np.asarray([127]).astype("int32") + dtype_np = np.int32(SQNN_DTYPE_TO_CODE["uint8"]) + quant_args = {"in_zero_point": zp_np[0], "in_scale": scale_np[0]} + dq_out = dequantize_test_driver( + in_dtype="uint8", + quant_args=quant_args, + axis=-1, + in_data=data, + ) + # Create variables with undefined shape and run with scalar inputs. 
+ input_data = relay.var("input_data", shape=data.shape, dtype="float32") + scale = relay.var("scale", shape=[relay.Any()], dtype="float32") + zp = relay.var("zp", shape=[relay.Any()], dtype="int32") + dtype = relay.var("dtype", shape=[]) + vm = build_simulated_dequantize(input_data, scale, zp, dtype) + sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_dq_out.asnumpy(), dq_out) + + # Now test int8 to float32 compilation. + data = np.random.uniform(low=0, high=255, size=[2, 5]).astype("int8") + data_fp = data.astype("float32") + # Get the reference quantize output. + dq_out = dequantize_test_driver( + in_dtype="int8", + quant_args=quant_args, + axis=-1, + in_data=data, + ) + # Run the simulated quantize without recompiling and confirm results match. + dtype_np = np.int32(SQNN_DTYPE_TO_CODE["int8"]) + sim_dq_out = vm.invoke("main", input_data=data_fp, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_dq_out.asnumpy(), dq_out) + + +if __name__ == "__main__": + test_simulated_dequantize() + test_dynamic_channels() + test_dynamic_dtype() diff --git a/tests/python/relay/test_op_qnn_simulated_quantize.py b/tests/python/relay/test_op_qnn_simulated_quantize.py new file mode 100644 index 000000000000..ee4ba209dcb8 --- /dev/null +++ b/tests/python/relay/test_op_qnn_simulated_quantize.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import te +import numpy as np +from tvm import relay +from tvm.contrib import graph_runtime +from tvm.runtime.vm import VirtualMachine +from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE + + +def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data): + shape = in_data.shape + input_data = relay.var("input_data", shape=shape, dtype=in_dtype) + output_zero_point = relay.const(quant_args["out_zero_point"]) + output_scale = relay.const(quant_args["out_scale"]) + quantized_output = relay.qnn.op.quantize( + input_data, + output_scale=output_scale, + output_zero_point=output_zero_point, + axis=axis, + out_dtype=out_dtype, + ) + mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) + mod = tvm.IRModule.from_expr(mod) + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(mod, "llvm", params=None) + rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + rt_mod.set_input(input_data=in_data) + rt_mod.set_input(**params) + rt_mod.run() + res = rt_mod.get_output(0).asnumpy() + return res + + +def build_simulated_quantize(input_data, scale, zp, dtype, axis=-1): + sim_q = relay.qnn.op.simulated_quantize( + input_data, + scale, + zp, + axis=axis, + out_dtype=dtype, + ) + mod = tvm.IRModule.from_expr(sim_q) + with tvm.transform.PassContext(opt_level=3): + vm_exec = relay.vm.compile(mod, "llvm", params=None) + vm = VirtualMachine(vm_exec, tvm.cpu(0)) + return vm + + +def verify_simulated_quantize_simple(dtype): + data = np.random.uniform(low=-128, high=127, size=[2, 
5]).astype("float32") + scale_np = np.float32(0.5) + zp_np = np.int32(127) + dtype_np = np.int32(SQNN_DTYPE_TO_CODE[dtype]) + quant_args = {"out_zero_point": zp_np, "out_scale": scale_np} + q_out = quantize_test_driver( + in_dtype="float32", + quant_args=quant_args, + axis=-1, + out_dtype=dtype, + in_data=data, + ) + input_data = relay.var("input_data", shape=data.shape, dtype="float32") + scale = relay.var("scale", shape=[]) + zp = relay.var("zp", shape=[]) + dtype = relay.var("dtype", shape=[]) + vm = build_simulated_quantize(input_data, scale, zp, dtype) + sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_q_out.asnumpy(), q_out) + + +def test_simulated_quantize(): + verify_simulated_quantize_simple("uint8") + verify_simulated_quantize_simple("int8") + verify_simulated_quantize_simple("int32") + + +def test_dynamic_channels(): + # Compile simulated quantize once but support either per-channel or scalar params. + data = np.random.uniform(low=-64, high=64, size=[2, 5]).astype("float32") + # Test scalar qnn params. + scale_np = np.asarray([0.5]).astype("float32") + zp_np = np.asarray([127]).astype("int32") + dtype_np = np.int32(SQNN_DTYPE_TO_CODE["uint8"]) + quant_args = {"out_zero_point": zp_np[0], "out_scale": scale_np[0]} + q_out = quantize_test_driver( + in_dtype="float32", + quant_args=quant_args, + axis=0, + out_dtype="uint8", + in_data=data, + ) + # Create variables with undefined shape and run with scalar inputs. 
+    input_data = relay.var("input_data", shape=data.shape, dtype="float32")
+    scale = relay.var("scale", shape=[relay.Any()], dtype="float32")
+    zp = relay.var("zp", shape=[relay.Any()], dtype="int32")
+    dtype = relay.var("dtype", shape=[])
+    vm = build_simulated_quantize(input_data, scale, zp, dtype, axis=0)
+    sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np)
+    np.testing.assert_equal(sim_q_out.asnumpy(), q_out)
+
+    # Now get the per-channel quantize output and compare without recompiling.
+    scale_np = np.array([0.5, 0.25]).astype("float32")
+    zp_np = np.array([127, 123]).astype("int32")
+
+    # Get the reference quantize output.
+    quant_args = {"out_zero_point": zp_np, "out_scale": scale_np}
+    q_out = quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=0,
+        out_dtype="uint8",
+        in_data=data,
+    )
+    # Run the simulated quantize without recompiling and confirm results match.
+    sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np)
+    np.testing.assert_equal(sim_q_out.asnumpy(), q_out)
+
+
+def test_dynamic_dtype():
+    # Compile simulated quantize once but support any output type of quantization.
+    data = np.random.uniform(low=-64, high=64, size=[2, 5]).astype("float32")
+    # Test scalar float32 to uint8.
+    scale_np = np.asarray([0.5]).astype("float32")
+    zp_np = np.asarray([127]).astype("int32")
+    dtype_np = np.int32(SQNN_DTYPE_TO_CODE["uint8"])
+    quant_args = {"out_zero_point": zp_np[0], "out_scale": scale_np[0]}
+    q_out = quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=-1,
+        out_dtype="uint8",
+        in_data=data,
+    )
+    # Create variables with undefined shape and run with scalar inputs.
+ input_data = relay.var("input_data", shape=data.shape, dtype="float32") + scale = relay.var("scale", shape=[relay.Any()], dtype="float32") + zp = relay.var("zp", shape=[relay.Any()], dtype="int32") + dtype = relay.var("dtype", shape=[]) + vm = build_simulated_quantize(input_data, scale, zp, dtype) + sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_q_out.asnumpy(), q_out) + + # Now test float32 to int32 compilation. + # Get the reference quantize output. + q_out = quantize_test_driver( + in_dtype="float32", + quant_args=quant_args, + axis=-1, + out_dtype="int32", + in_data=data, + ) + # Run the simulated quantize without recompiling and confirm results match. + dtype_np = np.int32(SQNN_DTYPE_TO_CODE["int32"]) + sim_q_out = vm.invoke("main", input_data=data, scale=scale_np, zp=zp_np, dtype=dtype_np) + np.testing.assert_equal(sim_q_out.asnumpy(), q_out) + + +if __name__ == "__main__": + test_simulated_quantize() + test_dynamic_channels() + test_dynamic_dtype() diff --git a/tests/python/topi/python/test_topi_qnn.py b/tests/python/topi/python/test_topi_qnn.py new file mode 100644 index 000000000000..a63f34fe08d0 --- /dev/null +++ b/tests/python/topi/python/test_topi_qnn.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test code for QNN operators.""" +import numpy as np +import tvm +from tvm import topi, relay, te +from tvm.contrib import graph_runtime +import tvm.topi.testing + + +def verify_simulated_quantize(data_shape, out_dtype, channels, axis): + # Create placeholder variables for all qnn inputs. + A = te.placeholder(data_shape, name="value", dtype="float32") + D = te.placeholder([], name="dtype", dtype="int32") + S = te.placeholder([te.size_var("scale_dim")], name="scale", dtype="float32") + Z = te.placeholder([te.size_var("zp_dim")], name="zp", dtype="int32") + SIM_Q = topi.nn.simulated_quantize(A, D, output_scale=S, output_zero_point=Z, axis=axis) + + # Create random numpy values to assign to inputs. + a_np = np.random.uniform(size=data_shape).astype("float32") + d_np = np.int32(topi.nn.SQNN_DTYPE_TO_CODE[out_dtype]) + s_np = np.random.uniform(low=1e-4, high=0.1, size=channels).astype("float32") + z_np = np.random.uniform(low=-10, high=10, size=channels).astype("int32") + q_np = np.zeros(shape=data_shape, dtype="float32") + + def check_device(device, ctx): + # Wrap the numpy arrays in nd arrays. + a = tvm.nd.array(a_np, ctx) + d = tvm.nd.array(d_np, ctx) + s = tvm.nd.array(s_np, ctx) + z = tvm.nd.array(z_np, ctx) + q = tvm.nd.array(q_np, ctx) + + # Construct equivalent relay graph. + per_channel = channels[0] != 1 + a_var = relay.var("a", shape=data_shape, dtype="float32") + if per_channel: + s_var = relay.const(s_np) + z_var = relay.const(z_np) + else: + s_var = relay.const(s_np[0]) + z_var = relay.const(z_np[0]) + real_q_op = relay.qnn.op.quantize(a_var, s_var, z_var, axis=axis, out_dtype=out_dtype) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(tvm.IRModule.from_expr(real_q_op), target=device) + + # Get real qnn quantize output. 
+        m = graph_runtime.GraphModule(lib["default"](ctx))
+        m.set_input("a", a_np)
+
+        m.run()
+        real_q_out = m.get_output(0)
+
+        # Compile and run the simulated quantize function.
+        with tvm.target.Target(device):
+            sched = tvm.topi.testing.get_injective_schedule(device)(SIM_Q)
+            func = tvm.build(sched, [A, D, S, Z, SIM_Q], device, name="sim_quantize")
+            func(a, d, s, z, q)
+
+        # Check correctness against the true qnn output.
+        tvm.testing.assert_allclose(q.asnumpy(), real_q_out.asnumpy().astype("float32"))
+
+    for target, ctx in tvm.testing.enabled_targets():
+        check_device(target, ctx)
+
+
+def test_simulated_quantize():
+    verify_simulated_quantize([1], "int8", [1], -1)
+    verify_simulated_quantize([2, 5], "int8", [5], 1)
+    verify_simulated_quantize([1, 32, 32, 32], "int8", [32], -1)
+    verify_simulated_quantize([1, 32, 32, 32], "uint8", [32], -2)
+    verify_simulated_quantize([2, 5], "int32", [5], 1)
+
+
+def verify_simulated_dequantize(data_shape, in_dtype, channels, axis):
+    # Create placeholder variables for all qnn inputs.
+    A = te.placeholder(data_shape, name="value", dtype="float32")
+    D = te.placeholder([], name="dtype", dtype="int32")
+    S = te.placeholder([te.size_var("scale_dim")], name="scale", dtype="float32")
+    Z = te.placeholder([te.size_var("zp_dim")], name="zp", dtype="int32")
+    SIM_DQ = topi.nn.simulated_dequantize(A, D, input_scale=S, input_zero_point=Z, axis=axis)
+
+    # Create random numpy values to assign to inputs.
+    a_np = np.random.uniform(low=-128, high=127, size=data_shape).astype(in_dtype)
+    a_np_f = a_np.astype("float32")
+    d_np = np.int32(topi.nn.SQNN_DTYPE_TO_CODE[in_dtype])
+    s_np = np.random.uniform(low=1e-4, high=0.1, size=channels).astype("float32")
+    z_np = np.random.uniform(low=-10, high=10, size=channels).astype("int32")
+    dq_np = np.zeros(shape=data_shape, dtype="float32")
+
+    def check_device(device, ctx):
+        # Wrap the numpy arrays in nd arrays.
+        a = tvm.nd.array(a_np_f, ctx)
+        d = tvm.nd.array(d_np, ctx)
+        s = tvm.nd.array(s_np, ctx)
+        z = tvm.nd.array(z_np, ctx)
+        dq = tvm.nd.array(dq_np, ctx)
+
+        # Construct equivalent relay graph.
+        per_channel = channels[0] != 1
+        a_var = relay.var("a", shape=data_shape, dtype=in_dtype)
+        if per_channel:
+            s_var = relay.const(s_np)
+            z_var = relay.const(z_np)
+        else:
+            s_var = relay.const(s_np[0])
+            z_var = relay.const(z_np[0])
+        real_dq_op = relay.qnn.op.dequantize(a_var, s_var, z_var, axis=axis)
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(tvm.IRModule.from_expr(real_dq_op), target=device)
+
+        # Get real qnn dequantize output.
+        m = graph_runtime.GraphModule(lib["default"](ctx))
+        m.set_input("a", a_np)
+
+        m.run()
+        real_dq_out = m.get_output(0)
+
+        # Compile and run the simulated dequantize function.
+        with tvm.target.Target(device):
+            sched = tvm.topi.testing.get_injective_schedule(device)(SIM_DQ)
+            func = tvm.build(sched, [A, D, S, Z, SIM_DQ], device, name="sim_quantize")
+            func(a, d, s, z, dq)
+
+        # Check correctness against the true qnn output.
+        tvm.testing.assert_allclose(dq.asnumpy(), real_dq_out.asnumpy().astype("float32"))
+
+    for target, ctx in tvm.testing.enabled_targets():
+        check_device(target, ctx)
+
+
+def test_simulated_dequantize():
+    verify_simulated_dequantize([1], "int8", [1], -1)
+    verify_simulated_dequantize([2, 5], "int8", [5], 1)
+    verify_simulated_dequantize([2, 5], "int8", [2], 0)
+    verify_simulated_dequantize([1, 32, 32, 32], "int8", [32], -1)
+    verify_simulated_dequantize([1, 32, 32, 32], "uint8", [32], -2)
+    verify_simulated_dequantize([2, 5], "int32", [5], 1)
+
+
+if __name__ == "__main__":
+    test_simulated_quantize()
+    test_simulated_dequantize()