diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h
new file mode 100644
index 000000000000..afbe1f343bb8
--- /dev/null
+++ b/include/tvm/ir/affine_type.h
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/affine_type.h
+ * \brief Quantized Tensor Types.
+ */
+#ifndef TVM_IR_AFFINE_TYPE_H_
+#define TVM_IR_AFFINE_TYPE_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/ir/type.h>
+
+namespace tvm {
+
+/*!
+ * \brief AffineType representation
+ * \sa AffineType
+ */
+class AffineTypeNode : public Object {
+ public:
+  /*!
+   * \brief Span that points to the original source code.
+   *        Reserved debug information.
+   */
+  mutable Span span;
+
+  static constexpr const char* _type_key = "AffineType";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AffineTypeNode.
+ * \sa AffineTypeNode
+ */
+class AffineType : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode);
+};
+
+/*!
+ * \brief TensorAffineType representation
+ * \sa TensorAffineType
+ *
+ * This type represents a quantized integer tensor that can be converted
+ * back to real space via the affine transformation
+ * x_real = scale * (x_quant - zero_point).
+ */
+class TensorAffineTypeNode : public AffineTypeNode {
+ public:
+  /*! \brief The scale of this type */
+  RelayExpr scale;
+  /*! \brief The zero point of this type */
+  RelayExpr zero_point;
+  /*! \brief The data type of this type */
+  DataType dtype;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("scale", &scale);
+    v->Visit("zero_point", &zero_point);
+    v->Visit("dtype", &dtype);
+  }
+
+  bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(scale, other->scale) && equal(zero_point, other->zero_point) &&
+           equal(dtype, other->dtype);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(scale);
+    hash_reduce(zero_point);
+    hash_reduce(dtype);
+  }
+
+  static constexpr const char* _type_key = "TensorAffineType";
+  TVM_DECLARE_BASE_OBJECT_INFO(TensorAffineTypeNode, AffineTypeNode);
+};
+
+/*!
+ * \brief Managed reference to TensorAffineTypeNode.
+ * \sa TensorAffineTypeNode
+ */
+class TensorAffineType : public AffineType {
+ public:
+  TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode);
+};
+
+/*!
+ * \brief TupleAffineType representation
+ * \sa TupleAffineType
+ */
+class TupleAffineTypeNode : public AffineTypeNode {
+ public:
+  /*! \brief The types of this tuple */
+  Array<TensorAffineType> types;
+
+  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("types", &types); }
+
+  bool SEqualReduce(const TupleAffineTypeNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    return equal(types, other->types);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(types);
+  }
+
+  static constexpr const char* _type_key = "TupleAffineType";
+  TVM_DECLARE_BASE_OBJECT_INFO(TupleAffineTypeNode, AffineTypeNode);
+};
+
+/*!
+ * \brief Managed reference to TupleAffineTypeNode.
+ * \sa TupleAffineTypeNode
+ */
+class TupleAffineType : public AffineType {
+ public:
+  TVM_DLL TupleAffineType(Array<TensorAffineType> types);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleAffineType, AffineType, TupleAffineTypeNode);
+};
+
+}  // namespace tvm
+#endif  // TVM_IR_AFFINE_TYPE_H_
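For intuition, the mapping documented on TensorAffineTypeNode is the standard uniform-quantization affine transform. A minimal NumPy sketch of the round trip (illustrative only, not part of the patch; the scale, zero point, and values are made up):

import numpy as np

scale, zero_point = 0.5, -1  # made-up per-tensor quantization parameters
x_quant = np.array([-128, -1, 126], dtype=np.int8)

# Dequantize: x_real = scale * (x_quant - zero_point)
x_real = scale * (x_quant.astype(np.int32) - zero_point)  # [-63.5, 0.0, 63.5]

# Quantize: invert the mapping, round, and clamp to the int8 range
x_round = np.round(x_real / scale) + zero_point
assert np.array_equal(np.clip(x_round, -128, 127).astype(np.int8), x_quant)
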
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index b4cc4421b169..83557a3eae19 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -21,6 +21,7 @@
 from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
 from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
 from .tensor_type import TensorType
+from .affine_type import TensorAffineType, TupleAffineType
 from .type_relation import TypeCall, TypeRelation
 from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
 from .op import Op, register_op_attr, register_intrin_lowering
diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py
new file mode 100644
index 000000000000..a1ce08017b1b
--- /dev/null
+++ b/python/tvm/ir/affine_type.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Types for quantized Tensors."""
+import tvm._ffi
+
+from .base import Node
+from . import _ffi_api
+
+
+class AffineType(Node):
+    """The base class of Affine Types."""
+
+    def __eq__(self, other):
+        """Compare two types for structural equivalence."""
+        return bool(tvm.ir.structural_equal(self, other))
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+@tvm._ffi.register_object("TensorAffineType")
+class TensorAffineType(AffineType):
+    """The quantized type of a tensor, with scale, zero point, and datatype.
+
+    The real-space value is calculated as x = (x_q - zero_point) * scale
+
+    Parameters
+    ----------
+    scale : Expr
+        The scale
+
+    zero_point : Expr
+        The zero point
+
+    dtype : str
+        The content data type.
+    """
+
+    def __init__(self, scale, zero_point, dtype):
+        self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype)
+
+
+@tvm._ffi.register_object("TupleAffineType")
+class TupleAffineType(AffineType):
+    """Affine types of a node with multiple outputs.
+
+    Parameters
+    ----------
+    types : List[TensorAffineType]
+        The affine types of the tuple's fields
+    """
+
+    def __init__(self, types):
+        self.__init_handle_by_constructor__(_ffi_api.TupleAffineType, types)
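A quick illustration of the new Python objects (editorial sketch, not part of the patch; it relies only on the constructors and the structural `__eq__` defined above):

import tvm
from tvm import relay
from tvm.ir import TensorAffineType, TupleAffineType

t = TensorAffineType(relay.const(0.5), relay.const(0), "int8")
tup = TupleAffineType([t, t])

# AffineType.__eq__ is structural, so two independently built types compare equal
assert t == TensorAffineType(relay.const(0.5), relay.const(0), "int8")
print(tup)  # rendered by the ReprPrinter registered in src/ir/affine_type.cc below
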
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 42bde838859a..c12e096e9051 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1008,7 +1008,7 @@ def _impl_v11(cls, inputs, attr, params):
         if len(inputs) == 3:
             value = fold_constant(_op.take(inputs[2], _op.const(0)))
         else:
-            value = 0
+            value = 0.0
 
         pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1))))
         pad_mode = attr.get("mode", b"constant").decode("utf-8")
diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 5f4c53772eec..783204fb700f 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -17,13 +17,12 @@
 """Relay functions for rewriting fake quantized ops."""
 import tvm
 from tvm import relay
+from tvm.ir import TensorAffineType, TupleAffineType
 from ..op import register_fake_quantization_to_integer
 
 
 def fold_constant(expr):
-    mod = tvm.IRModule.from_expr(expr)
-    mod = relay.transform.FoldConstant()(mod)
-    return mod["main"].body
+    return relay.transform.FoldConstantExpr(expr, tvm.IRModule())
 
 
 @register_fake_quantization_to_integer("qnn.dequantize")
@@ -31,7 +30,7 @@ def dequantize(expr, type_map):
     """Remove dequantize op"""
     out = expr.args[0]
     t = type_map[expr]
-    return [out, t.scale, t.zero_point, t.dtype]
+    return [out, t]
 
 
 @register_fake_quantization_to_integer("qnn.quantize")
@@ -54,23 +53,26 @@ def quantize(expr, type_map):
         expr.args[2],
         out_dtype=expr.attrs.out_dtype,
     )
-    return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype]
+    return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)]
 
 
-def register_unary_identity(op_name, op):
+def register_unary_identity(op_name):
     def identity(expr, type_map):
         assert len(expr.args) == 1
         arg = expr.args[0]
         t = type_map[arg]
-        out = op(arg, **expr.attrs)
-        return [out, t.scale, t.zero_point, t.dtype]
+        return [expr, t]
 
     return register_fake_quantization_to_integer(op_name, identity)
 
 
-register_unary_identity("reshape", relay.op.reshape)
-register_unary_identity("transpose", relay.op.transpose)
-register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d)
+register_unary_identity("reshape")
+register_unary_identity("squeeze")
+register_unary_identity("strided_slice")
+register_unary_identity("transpose")
+register_unary_identity("expand_dims")
+register_unary_identity("nn.max_pool2d")
+register_unary_identity("nn.batch_flatten")
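Each rewrite registered this way receives the fake-quantized call and the affine-type map, and returns the rewritten expression plus its affine type. An op that is not a pure identity needs only a few more lines; for example, a hypothetical nn.relu rewrite (editorial sketch, not part of the patch; it assumes a positive per-tensor scale, under which real-space relu(x) becomes max(x_q, zero_point) in quantized space):

@register_fake_quantization_to_integer("nn.relu")
def relu(expr, type_map):
    """Rewrite a relu op against the quantized zero point"""
    arg = expr.args[0]
    t = type_map[arg]
    # real 0.0 maps to the zero point, so clamp the integer tensor against it
    zero = relay.op.cast(t.zero_point, t.dtype)
    return [relay.op.maximum(arg, zero), t]
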
@@ -81,7 +83,7 @@ def avgpool2d(expr, type_map):
     arg = relay.op.cast(arg, "int32")
     out = relay.op.nn.avg_pool2d(arg, **expr.attrs)
     out = relay.op.cast(out, t.dtype)
-    return [out, t.scale, t.zero_point, t.dtype]
+    return [out, t]
 
 
 @register_fake_quantization_to_integer("nn.bias_add")
@@ -99,10 +101,10 @@ def bias_add(expr, type_map):
             b_t.zero_point,
             in_scale,
             in_zero_point,
             out_dtype=x_t.dtype,
         )
     out = relay.op.nn.bias_add(x, b, **expr.attrs)
-    return [out, x_t.scale, x_t.zero_point, x_t.dtype]
+    return [out, x_t]
 
 
 @register_fake_quantization_to_integer("nn.conv2d")
@@ -118,7 +120,23 @@ def conv2d(expr, type_map):
     out = relay.qnn.op.conv2d(
         x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
     )
-    return [out, conv_scale, conv_zp, out.attrs.out_dtype]
+    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)]
+
+
+@register_fake_quantization_to_integer("nn.dense")
+def dense(expr, type_map):
+    """Rewrite a dense op"""
+    attrs = {**expr.attrs}
+    attrs.pop("out_dtype")
+    x, weight = expr.args
+    x_t = type_map[x]
+    w_t = type_map[weight]
+    dense_scale = fold_constant(x_t.scale * w_t.scale)
+    dense_zp = relay.const(0)
+    out = relay.qnn.op.dense(
+        x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
+    )
+    return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]
 
 
 @register_fake_quantization_to_integer("concatenate")
@@ -126,8 +144,9 @@ def concat(expr, type_map):
     """Rewrite a concat op"""
     scales = []
     zps = []
-    for arg in expr.args[0].fields:
-        t = type_map[arg]
+
+    tuple_type = type_map[expr.args[0]]
+    for t in tuple_type.types:
         scales.append(t.scale)
         zps.append(t.zero_point)
 
@@ -141,7 +160,21 @@ def concat(expr, type_map):
         out_type.zero_point,
         **expr.attrs,
     )
-    return [out, out_type.scale, out_type.zero_point, out_type.dtype]
+    return [out, out_type]
+
+
+@register_fake_quantization_to_integer("split")
+def split(expr, type_map):
+    """Rewrite a split op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    attrs = {**expr.attrs}
+    if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm):
+        num_split = attrs["indices_or_sections"].value
+        attrs["indices_or_sections"] = num_split
+    else:
+        num_split = len(attrs["indices_or_sections"]) + 1
+    return [expr, TupleAffineType([t] * num_split)]
@@ -163,4 +196,133 @@ def clip(expr, type_map):
     amin = relay.op.round(relay.op.const(amin) / scale + z_p)
     amax = relay.op.round(relay.op.const(amax) / scale + z_p)
     out = relay.op.minimum(relay.op.maximum(arg, amin), amax)
-    return [out, t.scale, t.zero_point, t.dtype]
+    return [out, t]
+
+
+@register_fake_quantization_to_integer("nn.pad")
+def pad(expr, type_map):
+    """Rewrite an nn.pad op"""
+    arg = expr.args[0]
+    t = type_map[arg]
+    pad_value = expr.args[1]
+    ## TF2ONNX will sometimes implement the pad_value as a constant without a quantize
+    ## To support that, the pass lets branches that terminate in a constant through
+    if pad_value in type_map:
+        ## If the pad value is calculated from a dequantize op, it should be in the type map,
+        ## and we need to make sure its affine type matches the arg's
+        pad_t = type_map[pad_value]
+        if not tvm.ir.structural_equal(t, pad_t):
+            pad_value = relay.qnn.op.requantize(
+                pad_value,
+                pad_t.scale,
+                pad_t.zero_point,
+                t.scale,
+                t.zero_point,
+                out_dtype=t.dtype,
+            )
+    else:
+        ## If the pad value is a constant, we need to quantize it
+        assert isinstance(pad_value, relay.expr.Constant)
+        pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point)
+
+    out = relay.op.nn.pad(arg, pad_value=pad_value, **expr.attrs)
+    return [out, t]
+
+
+def get_binary_types(expr, type_map):
+    """Get affine types of a binary op's inputs and unify them"""
+    ## Support the case where one input is quantized and the other is a constant float
+    left = expr.args[0]
+    right = expr.args[1]
+    left_t = None
+    right_t = None
+
+    if left in type_map:
+        left_t = type_map[left]
+    if right in type_map:
+        right_t = type_map[right]
+
+    out_t = type_map[expr]
+    if left_t is None and right_t is None:
+        raise TypeError("Neither input is quantized!")
+    if left_t is None:
+        assert isinstance(left, relay.expr.Constant)
+        left = relay.qnn.op.quantize(
+            left, right_t.scale, right_t.zero_point, out_dtype=right_t.dtype
+        )
+        left_t = right_t
+        out_t = right_t
+    if right_t is None:
+        assert isinstance(right, relay.expr.Constant)
+        right = relay.qnn.op.quantize(
+            right, left_t.scale, left_t.zero_point, out_dtype=left_t.dtype
+        )
+        right_t = left_t
+        out_t = left_t
+
+    # Handle the case of mismatched inputs
+    if left_t.dtype != out_t.dtype:
+        out_t = left_t
+
+    return left, right, left_t, right_t, out_t
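To see why the binary helpers below requantize mismatched inputs before an integer min/max, consider two int8 tensors quantized with different scales (plain-Python sketch with made-up numbers, not part of the patch):

# x quantized with scale 0.5, y with scale 0.25, zero points both 0
x_q, y_q = 10, 10                       # same stored integer...
x_real, y_real = 0.5 * x_q, 0.25 * y_q  # ...but different reals: 5.0 vs 2.5

# max() on the raw integers would wrongly report a tie, so rescale y into
# x's scale first: y_q' = y_q * (0.25 / 0.5)
y_q_rescaled = int(y_q * 0.25 / 0.5)  # 5
assert 0.5 * max(x_q, y_q_rescaled) == max(x_real, y_real)
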
+
+
+def register_binary_qnn(op_name, op):
+    """Register a binary op that converts to QNN"""
+
+    def binary(expr, type_map):
+        left, right, left_t, right_t, out_t = get_binary_types(expr, type_map)
+        out = op(
+            left,
+            right,
+            left_t.scale,
+            left_t.zero_point,
+            right_t.scale,
+            right_t.zero_point,
+            out_t.scale,
+            out_t.zero_point,
+        )
+        return [out, out_t]
+
+    return register_fake_quantization_to_integer(op_name, binary)
+
+
+# Use lambdas here to avoid a circular import problem
+# pylint: disable=unnecessary-lambda
+register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args))
+register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args))
+register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args))
+
+
+def register_binary_identity(op_name, op):
+    """Register a binary op that works directly on int8"""
+
+    def binary(expr, type_map):
+        left, right, left_t, right_t, out_t = get_binary_types(expr, type_map)
+        if left_t != out_t:
+            left = relay.qnn.op.requantize(
+                left,
+                left_t.scale,
+                left_t.zero_point,
+                out_t.scale,
+                out_t.zero_point,
+                out_dtype=out_t.dtype,
+            )
+
+        if right_t != out_t:
+            right = relay.qnn.op.requantize(
+                right,
+                right_t.scale,
+                right_t.zero_point,
+                out_t.scale,
+                out_t.zero_point,
+                out_dtype=out_t.dtype,
+            )
+        out = op(left, right)
+        return [out, out_t]
+
+    return register_fake_quantization_to_integer(op_name, binary)
+
+
+register_binary_identity("minimum", relay.op.minimum)
+register_binary_identity("maximum", relay.op.maximum)
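Together, these registrations drive the C++ pass later in this patch. A minimal end-to-end sketch of invoking the transform (editorial example mirroring the tests at the end of the patch; shapes and quantization parameters are arbitrary):

import tvm
from tvm import relay

x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
op = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0))
op = relay.op.nn.max_pool2d(op, [3, 3])
op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(0))

mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(op))
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
print(mod_int)  # the dequantize/quantize pair is gone; max_pool2d runs on int8
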
diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc
new file mode 100644
index 000000000000..3454b6011c9b
--- /dev/null
+++ b/src/ir/affine_type.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/affine_type.cc
+ * \brief The type information for quantized nodes.
+ */
+#include <tvm/ir/affine_type.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+
+using tvm::ReprPrinter;
+using namespace tvm::runtime;
+
+TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) {
+  ObjectPtr<TensorAffineTypeNode> n = make_object<TensorAffineTypeNode>();
+  n->scale = std::move(scale);
+  n->zero_point = std::move(zero_point);
+  n->dtype = std::move(dtype);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode);
+
+TVM_REGISTER_GLOBAL("ir.TensorAffineType")
+    .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) {
+      return TensorAffineType(scale, zero_point, dtype);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TensorAffineTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const TensorAffineTypeNode*>(ref.get());
+      p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", "
+                << node->dtype << ")";
+    });
+
+TupleAffineType::TupleAffineType(Array<TensorAffineType> types) {
+  ObjectPtr<TupleAffineTypeNode> n = make_object<TupleAffineTypeNode>();
+  n->types = std::move(types);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleAffineTypeNode);
+
+TVM_REGISTER_GLOBAL("ir.TupleAffineType").set_body_typed([](Array<TensorAffineType> types) {
+  return TupleAffineType(types);
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TupleAffineTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const TupleAffineTypeNode*>(ref.get());
+      p->stream << "TupleAffineType([";
+      for (size_t i = 0; i < node->types.size(); ++i) {
+        p->stream << node->types[i];
+        if (i < node->types.size() - 1) {
+          p->stream << ", ";
+        }
+      }
+      p->stream << "])";
+    });
+
+}  // namespace tvm
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 751abfc5ca81..2f1d7d8da16c 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -51,14 +51,20 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
   int axis = quantize_attrs->axis;
-  axis = (axis < 0) ? data->shape.size() + axis : axis;
-  ICHECK_LT(axis, static_cast<int>(data->shape.size()))
-      << "axis " << quantize_attrs->axis << " is out of range";
+  auto rank = static_cast<int>(data->shape.size());
+  axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
+  ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range";
   ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range";
 
+  PrimExpr axis_shape;
+  if (rank > 0) {
+    axis_shape = data->shape[axis];
+  } else {
+    axis_shape = Integer(1);
+  }
   // Check and assign types for scale and zero points.
-  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // scale
-  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);   // zero point
+  AssignType(types[1], DataType::Float(32), axis_shape, reporter);  // scale
+  AssignType(types[2], DataType::Int(32), axis_shape, reporter);    // zero point
 
   const Array<IndexExpr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
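The rank-0 handling above is what makes scalar quantization type-check. A small sketch of what is now accepted (editorial example matching the new test_scalar_float32_to_int8 test later in this patch):

import tvm
from tvm import relay

x = relay.var("x", shape=(), dtype="float32")
q = relay.qnn.op.quantize(x, relay.const(0.5), relay.const(-1), axis=-1, out_dtype="int8")
mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(q))
# with rank 0, axis -1 resolves to 0 and scale/zero_point must be scalars
print(mod["main"].checked_type.ret_type)
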
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 769f37205790..46de3522061b 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -279,14 +279,20 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
   int axis = requantize_attrs->axis;
-  axis = (axis == -1) ? data->shape.size() - 1 : axis;
-  ICHECK_LT(axis, static_cast<int>(data->shape.size()))
-      << "axis " << requantize_attrs->axis << " is out of range";
+  auto rank = static_cast<int>(data->shape.size());
+  axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis;
+  ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << requantize_attrs->axis << " is out of range";
   ICHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range";
 
+  PrimExpr axis_shape;
+  if (rank > 0) {
+    axis_shape = data->shape[axis];
+  } else {
+    axis_shape = Integer(1);
+  }
   // Check and assign types for scale and zero points.
-  AssignType(types[1], DataType::Float(32), data->shape[axis], reporter);  // input_scale
-  AssignType(types[2], DataType::Int(32), data->shape[axis], reporter);   // input_zero_pt
+  AssignType(types[1], DataType::Float(32), axis_shape, reporter);  // input_scale
+  AssignType(types[2], DataType::Int(32), axis_shape, reporter);    // input_zero_pt
   // For now, requantize output tensor is limited to full tensor uniform quantization.
   ICHECK(IsScalarType(types[3], DataType::Float(32)));  // output_scale
   ICHECK(IsScalarType(types[4], DataType::Int(32)));    // output_zero_point
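Requantize gets the same rank-0 treatment. A matching sketch (editorial example mirroring the new test_scalar_same_scale test later in this patch, which also checks that no fixed-point right_shift appears when input and output scales match):

import tvm
from tvm import relay

x = relay.var("x", shape=(), dtype="int32")
op = relay.qnn.op.requantize(
    x,
    input_scale=relay.const(0.5),
    input_zero_point=relay.const(0),
    output_scale=relay.const(0.5),
    output_zero_point=relay.const(0),
    out_dtype="int8",
)
mod = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(op))
assert "right_shift" not in mod.astext()
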
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
index f883b4113656..b5f434e74c43 100644
--- a/src/relay/transforms/fake_quantization_to_integer.cc
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -23,10 +23,14 @@
  * to actual integer operations.
  */
 
+#include <tvm/ir/affine_type.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 
+namespace tvm {
+namespace relay {
+
 /* Description of FakeQuantizationToInteger
  *
  * The purpose of this pass is to find regions of the graph that follow
@@ -63,65 +67,6 @@
  * rewritten subgraph and the processing continues
  */
 
-namespace tvm {
-namespace relay {
-
-/*!
- * \brief AffineType representation
- * \sa AffineType
- */
-class AffineTypeNode : public Object {
- public:
-  /*! \brief The scale of this type */
-  Expr scale;
-  /*! \brief The zero point of this type */
-  Expr zero_point;
-  /*! \brief The data type of this type */
-  DataType dtype;
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("scale", &scale);
-    v->Visit("zero_point", &zero_point);
-    v->Visit("dtype", &dtype);
-  }
-
-  bool SEqualReduce(const AffineTypeNode* other, SEqualReducer equal) const {
-    equal->MarkGraphNode();
-    return equal(scale, other->scale) && equal(zero_point, other->zero_point) &&
-           equal(dtype, other->dtype);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce->MarkGraphNode();
-    hash_reduce(scale);
-    hash_reduce(zero_point);
-    hash_reduce(dtype);
-  }
-
-  static constexpr const bool _type_has_method_sequal_reduce = true;
-  static constexpr const bool _type_has_method_shash_reduce = true;
-  static constexpr const char* _type_key = "AffineTypeNode";
-  TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object);
-};
-
-/*!
- * \brief Managed reference to AffineTypes.
- * \sa AffineTypeNode
- */
-class AffineType : public ObjectRef {
- public:
-  TVM_DLL AffineType(Expr scale, Expr zero_point, DataType dtype) {
-    ObjectPtr<AffineTypeNode> n = make_object<AffineTypeNode>();
-    n->scale = std::move(scale);
-    n->zero_point = std::move(zero_point);
-    n->dtype = std::move(dtype);
-    data_ = std::move(n);
-  }
-  TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode);
-};
-
-TVM_REGISTER_NODE_TYPE(AffineTypeNode);
-
 using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
 using ExprMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>;
 using AffineTypeMap = Map<Expr, AffineType>;
@@ -147,8 +92,14 @@ class SubgraphExtractor : public ExprVisitor {
   }
   const AffineTypeMap GetAffineTypes() { return affine_types_; }
   void VisitExpr(const Expr& expr) override {
+    // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
+    // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
+    // abort the rewrite.
    if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
-        expr.as<TupleNode>() == nullptr) {
+        expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
+        expr.as<ConstantNode>() == nullptr) {
+      LOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
+                << " a fake quantize region, aborting this rewrite";
       is_fake_quantized_ = false;
     } else {
       ExprVisitor::VisitExpr(expr);
@@ -162,13 +113,14 @@ class SubgraphExtractor : public ExprVisitor {
       VisitExpr(call_node->args[0]);
       // Collect type of quantize ops
       affine_types_.Set(GetRef<Expr>(call_node),
-                        AffineType(call_node->args[1], call_node->args[2],
-                                   call_node->checked_type().as<TensorTypeNode>()->dtype));
+                        TensorAffineType(call_node->args[1], call_node->args[2],
+                                         call_node->checked_type().as<TensorTypeNode>()->dtype));
     } else if (call_node->op == dequantize_op_) {
       // Collect type of dequantize ops
-      affine_types_.Set(GetRef<Expr>(call_node),
-                        AffineType(call_node->args[1], call_node->args[2],
-                                   call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype));
+      affine_types_.Set(
+          GetRef<Expr>(call_node),
+          TensorAffineType(call_node->args[1], call_node->args[2],
+                           call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype));
     } else {
       // run normally on everything else.
       ExprVisitor::VisitExpr_(call_node);
@@ -225,19 +177,38 @@ class SubgraphMutator : public ExprMutator {
       }
       // Call the rewrite
       Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
-      // Save teh outputs of the rewrite
-      ICHECK(vals.size() == 4)
+      // Save the outputs of the rewrite
+      ICHECK(vals.size() == 2)
           << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
          << AsText(op, false);
       out = Downcast<Expr>(vals[0]);
-      affine_types_.Set(out, AffineType(Downcast<Expr>(vals[1]), Downcast<Expr>(vals[2]),
-                                        DataType(String2DLDataType(Downcast<String>(vals[3])))));
+      affine_types_.Set(out, Downcast<AffineType>(vals[1]));
     } else {
       ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
                     << AsText(GetRef<Expr>(call_node), false);
     }
     return out;
   }
+
+  Expr VisitExpr_(const TupleNode* node) {
+    Expr expr = ExprMutator::VisitExpr_(node);
+    auto new_node = expr.as<TupleNode>();
+    Array<TensorAffineType> types;
+    for (Expr field : new_node->fields) {
+      ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
+      types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
+    }
+    affine_types_.Set(expr, TupleAffineType(types));
+    return expr;
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* node) {
+    Expr expr = ExprMutator::VisitExpr_(node);
+    auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
+    affine_types_.Set(expr, tuple_type->types[node->index]);
+    return expr;
+  }
+
   ExprSet subgraph_;
   AffineTypeMap affine_types_;
   AffineType out_type_;
diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py
index 345e8b815da1..322382ca002c 100644
--- a/tests/python/relay/test_op_qnn_quantize.py
+++ b/tests/python/relay/test_op_qnn_quantize.py
@@ -88,6 +88,20 @@ def test_float32_to_int8():
     )
 
 
+def test_scalar_float32_to_int8():
+    data = np.array(-63.5).astype("float32")
+    output = np.array(-128).astype("int8")
+    quant_args = {"out_zero_point": np.int32(-1), "out_scale": np.float32(0.5)}
+    quantize_test_driver(
+        in_dtype="float32",
+        quant_args=quant_args,
+        axis=-1,
+        out_dtype="int8",
+        in_data=data,
+        verify_output_data=output,
+    )
+
+
 def test_channelwise_axis_0():
     data = (
         np.array([-63.5, -63, -62.5, -62, -61.5, 30, 31, 31.5, 31.75, 32])
@@ -163,6 +177,7 @@ def test_dynamic_quantize():
 if __name__ == "__main__":
     test_float32_to_uint8()
     test_float32_to_int8()
+    test_scalar_float32_to_int8()
test_channelwise_axis_0() test_channelwise_axis_1() test_dynamic_quantize() diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index ad9805e74929..0f512df25cdf 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -92,6 +92,24 @@ def test_same_scale(): verify(mod, (golden_data, golden_output)) +def test_scalar_same_scale(): + # Have same scales, everything within range + golden_data = np.array(-10).astype("int32") + golden_output = golden_data + + for rounding in roundings: + mod = get_mod( + data_shape=(), + data_dtype="int32", + out_dtype="int8", + input_scale=0.5, + output_scale=0.5, + rounding=rounding, + ) + assert "right_shift" not in mod.astext() + verify(mod, (golden_data, golden_output)) + + def test_downscale(): for rounding in roundings: mod = get_mod( @@ -437,6 +455,7 @@ def test_per_channel_different_scale(): if __name__ == "__main__": test_same_scale() + test_scalar_same_scale() test_downscale() test_upscale() test_non_power_of_two() diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 3271379cf3ef..1e7d749ff418 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -22,6 +22,25 @@ from tvm import relay +def compare_fq_to_int(expr, args, allow_rounding_error=False): + mod = tvm.IRModule.from_expr(expr) + mod = tvm.relay.transform.InferType()(mod) + + mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod) + assert not tvm.ir.structural_equal(mod, mod_int) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(*args).numpy() + + ex = relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") + result_int = ex.evaluate()(*args).numpy() + + if allow_rounding_error: + assert np.all(np.abs(result - result_int) <= 1) + else: + assert np.array_equal(result, result_int) + + def test_fake_quantize_conv(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") @@ -35,23 +54,29 @@ def test_fake_quantize_conv(): ) op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np, w_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np).asnumpy() +def test_fake_quantize_dense(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) - assert np.array_equal(result, result2) + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 
127, size=[256, 64], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) def test_fake_transpose_quantize_conv(): @@ -65,23 +90,10 @@ def test_fake_transpose_quantize_conv(): op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) op = relay.qnn.op.quantize(op, one, zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) - - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np).asnumpy() - - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np).asnumpy() - - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np, w_np]) def test_fake_transpose_quantize_conv_bias_add(): @@ -97,24 +109,32 @@ def test_fake_transpose_quantize_conv_bias_add(): op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) op = relay.qnn.op.quantize(op, one, zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np, w_np, bias_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, w_np, bias_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np, w_np, bias_np).asnumpy() +def test_fake_transpose_quantize_conv_bias_add_mismatch(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + two = relay.const(2.0) + zero = relay.const(0) - assert np.array_equal(result, result2) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, two, zero)) + op = relay.qnn.op.quantize(op, one, zero) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") + + compare_fq_to_int(op, [x_np, w_np, bias_np]) def test_fake_quantize_maxpool(): @@ -125,101 +145,121 @@ def test_fake_quantize_maxpool(): op = relay.op.nn.max_pool2d(x, [3, 3]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = 
ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_avgpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np], True) -def test_fake_quantize_avgpool(): +def test_fake_quantize_reshape(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.op.reshape(x, [1, 3, -1]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_expand_dims(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.expand_dims(x, axis=1) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_squeeze(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.squeeze(x, axis=[0]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.all(np.abs(result - result2) <= 1) + compare_fq_to_int(op, [x_np]) -def test_fake_quantize_reshape(): +def test_fake_quantize_strided_slice(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.reshape(x, [1, 3, -1]) + op = relay.op.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 1, 112, 112]) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_split(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.split(x, axis=3, indices_or_sections=2) + op = relay.qnn.op.quantize(op[0], relay.const(2.0), zero) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = 
ex.evaluate()(x_np).asnumpy() + op = relay.op.split(x, axis=3, indices_or_sections=[56, 112, 168]) + op = relay.qnn.op.quantize(op[1], relay.const(2.0), zero) - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) -def test_fake_quantize_transpose_reshape(): +def test_fake_quantize_batch_flatten(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") zero = relay.const(0) x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) - op = relay.op.transpose(x, [1, 0, 2, 3]) - op = relay.op.reshape(op, [3, -1]) + op = relay.op.nn.batch_flatten(x) op = relay.qnn.op.quantize(op, relay.const(2.0), zero) - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() +def test_fake_quantize_transpose_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.transpose(x, [1, 0, 2, 3]) + op = relay.op.reshape(op, [3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - assert np.array_equal(result, result2) + compare_fq_to_int(op, [x_np]) def test_fake_quantize_concat(): @@ -234,24 +274,11 @@ def test_fake_quantize_concat(): concat = relay.op.concatenate(inputs, axis=1) out = relay.qnn.op.quantize(concat, relay.const(3.5), zero) - mod = tvm.IRModule.from_expr(out) - mod = tvm.relay.transform.InferType()(mod) - inputs_np = [] for i in range(4): inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8")) - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) - - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(*inputs_np).asnumpy() - - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(*inputs_np).asnumpy() - - assert np.array_equal(result, result2) + compare_fq_to_int(out, inputs_np) def test_fake_quantize_clip(): @@ -261,19 +288,67 @@ def test_fake_quantize_clip(): op = relay.op.clip(x, 0, 6) op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") - mod = tvm.IRModule.from_expr(op) - mod = tvm.relay.transform.InferType()(mod) - x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") - mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) - assert not tvm.ir.structural_equal(mod, mod2) - mod2 = tvm.relay.transform.FoldConstant()(mod2) + compare_fq_to_int(op, [x_np]) - ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np).asnumpy() - ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") - result2 = ex.evaluate()(x_np).asnumpy() 
+@pytest.mark.parametrize(
+    "operator",
+    [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum],
+)
+def test_fake_quantize_binary(operator):
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+    x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(0))
+
+    y = relay.var("y", shape=[1, 3, 224, 224], dtype="int8")
+    y = relay.qnn.op.dequantize(y, relay.const(0.2), relay.const(0))
+
+    op = operator(x, y)
+    if operator == relay.op.multiply:
+        out_scale = relay.const(20.0)
+    else:
+        out_scale = relay.const(0.1)
+
+    op = relay.qnn.op.quantize(op, out_scale, relay.const(0), out_dtype="int8")
+
+    x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+    y_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+
+    compare_fq_to_int(op, [x_np, y_np])
+
+
+@pytest.mark.parametrize(
+    "operator",
+    [
+        relay.op.add,
+        relay.op.multiply,
+        relay.op.subtract,
+        relay.op.minimum,
+        relay.op.maximum,
+    ],
+)
+def test_fake_quantize_binary_const(operator):
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+    x = relay.qnn.op.dequantize(x, relay.const(0.1), relay.const(10))
+
+    y = relay.const(1.0)
+
+    op = operator(x, y)
+    op = relay.qnn.op.quantize(op, relay.const(0.1), relay.const(10), out_dtype="int8")
+
+    x_np = np.random.randint(-25, 25, size=[1, 3, 224, 224], dtype="int8")
+
+    compare_fq_to_int(op, [x_np])
+
+
+def test_fake_quantize_pad():
+    x = relay.var("x", shape=[1, 383, 128], dtype="int8")
+    x = relay.qnn.op.dequantize(x, relay.const(1.0), relay.const(10))
+    op = relay.op.nn.pad(x, [[0, 0], [0, 1], [0, 0]], 0.0)
+    op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(10), out_dtype="int8")
+
+    x_np = np.random.randint(-25, 25, size=[1, 383, 128], dtype="int8")
 
-    assert np.array_equal(result, result2)
+    compare_fq_to_int(op, [x_np])
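A closing note on test_fake_quantize_pad: because the constant pad value 0.0 is quantized with the argument's scale and zero point, padded cells land exactly on the zero point and still dequantize to real 0.0. The arithmetic, using the test's parameters (plain-Python check, not part of the patch):

scale, zero_point = 1.0, 10
pad_q = round(0.0 / scale) + zero_point      # quantized pad value: 10
assert scale * (pad_q - zero_point) == 0.0   # dequantizes back to real zero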