From 74baf1cdc948f8bcb44100ad674c9e916156c2e4 Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Tue, 13 Jul 2021 08:18:21 +0300 Subject: [PATCH] Add qnn batch_matmul operator (#8401) * Add qnn batch_matmul operator - add support of the different out type for x86 batch_matmul * Fix code style * Add out_dtype to generic batch_matmul * Restore fixe in batch_matmul for dynamic shapes * Fix documentation for qnn.batch_matmul * Remove debug code * Modify zero point for qnn batch_matmul test --- python/tvm/relay/op/strategy/x86.py | 6 +- python/tvm/relay/qnn/op/qnn.py | 38 +++ python/tvm/topi/nn/batch_matmul.py | 28 +- python/tvm/topi/x86/batch_matmul.py | 25 +- src/relay/op/nn/nn.cc | 53 +--- src/relay/op/nn/nn.h | 54 ++++ src/relay/qnn/op/batch_matmul.cc | 216 +++++++++++++++ .../python/relay/test_op_qnn_batch_matmul.py | 247 ++++++++++++++++++ 8 files changed, 600 insertions(+), 67 deletions(-) create mode 100644 src/relay/qnn/op/batch_matmul.cc create mode 100644 tests/python/relay/test_op_qnn_batch_matmul.py diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 6a4030514580..a6e141f2753b 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -521,14 +521,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() if is_dynamic(out_type) or is_auto_scheduler_enabled(): strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul, need_auto_scheduler_layout=True), + wrap_compute_batch_matmul( + topi.nn.batch_matmul, need_auto_scheduler_layout=True, need_out_dtype=True + ), wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul), name="batch_matmul.generic", plevel=10, ) else: strategy.add_implementation( - wrap_compute_batch_matmul(topi.x86.batch_matmul), + wrap_compute_batch_matmul(topi.x86.batch_matmul, need_out_dtype=True), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", plevel=10, diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index f02f8227e14a..e74256ec74c3 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -682,6 +682,44 @@ def subtract( ) +def batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype="int32"): + r""" + Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data + in batch. + + .. math:: + + \mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T) + + Parameters + ---------- + x : tvm.relay.Expr + The first quantized input. + A quantized tensor is represented in following manner + `A = scale_a x (QA - zp_A)` + where QA is quantized tensor, scale_a and zp_A are quantization + params. + y : tvm.relay.Expr + The second quantized input. + x_zero_point: tvm.relay.Expr + The first input zero point. + y_zero_point: tvm.relay.Expr + The second input zero point. + x_scale: tvm.relay.Expr + The scale for the first input tensor. + y_scale: tvm.relay.Expr + The scale for the second input tensor. + out_dtype : str, optional + Specifies the output data type for mixed precision dense can be int32 or int16. + + Returns + ------- + result: tvm.relay.Expr + The computed result. 
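+
+    Examples
+    --------
+    A minimal usage sketch; the shapes, scales and zero points below are
+    illustrative only.
+
+    .. code-block:: python
+
+        from tvm import relay
+
+        x = relay.var("x", shape=(1, 4, 5), dtype="int8")
+        y = relay.var("y", shape=(1, 3, 5), dtype="int8")
+        out = relay.qnn.op.batch_matmul(
+            x,
+            y,
+            x_zero_point=relay.const(0, "int32"),
+            y_zero_point=relay.const(0, "int32"),
+            x_scale=relay.const(0.5, "float32"),
+            y_scale=relay.const(0.5, "float32"),
+        )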
+ """ + return _make.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype) + + # register fuse pattern for qnn ops reg.register_pattern("qnn.quantize", OpPattern.OPAQUE) reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE) diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index b6ed5a373e81..a1212668affa 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -21,7 +21,7 @@ from ..utils import get_const_tuple -def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): +def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout="", out_dtype=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. Supports broadcasting for batch dimension. @@ -67,12 +67,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""): N = y.shape[1] oshape = (batch, M, N) - output = te.compute( - oshape, - lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), - tag="batch_matmul", - attrs={"layout_free_placeholders": [y]}, - ) + if out_dtype is None or out_dtype == x.dtype: + output = te.compute( + oshape, + lambda b, i, j: te.sum( + x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k + ), + tag="batch_matmul", + attrs={"layout_free_placeholders": [y]}, + ) + else: + output = te.compute( + oshape, + lambda b, i, j: te.sum( + x[b if XB != 1 else 0, i, k].astype(out_dtype) + * y[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ), + tag="batch_matmul", + attrs={"layout_free_placeholders": [y]}, + ) if auto_scheduler_rewritten_layout: output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 37bdd09d6ca6..35f4a9aba456 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -25,7 +25,7 @@ @autotvm.register_topi_compute("batch_matmul.x86") -def batch_matmul(cfg, x, y, out_shape=None): +def batch_matmul(cfg, x, y, out_shape=None, out_dtype=None): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data in batch. Supports broadcasting in batch dimension. @@ -60,11 +60,24 @@ def batch_matmul(cfg, x, y, out_shape=None): _default_batch_matmul_config(cfg, M, N, K) k = te.reduce_axis((0, K), name="k") - C = te.compute( - (B, M, N), - lambda b, i, j: te.sum(x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k), - tag="batch_matmul", - ) + if out_dtype is None or out_dtype == x.dtype: + C = te.compute( + (B, M, N), + lambda b, i, j: te.sum( + x[b if XB != 1 else 0, i, k] * y[b if YB != 1 else 0, j, k], axis=k + ), + tag="batch_matmul", + ) + else: + C = te.compute( + (B, M, N), + lambda b, i, j: te.sum( + x[b if XB != 1 else 0, i, k].astype(out_dtype) + * y[b if YB != 1 else 0, j, k].astype(out_dtype), + axis=k, + ), + tag="batch_matmul", + ) return C diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d09a8495b549..76a12e27c361 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -935,57 +935,6 @@ If the input has size k on axis 1, then both gamma and beta have shape (k,). 
// relay.nn.batch_matmul TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs); -bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { - ICHECK_EQ(types.size(), 3); - const auto* x = types[0].as(); - const auto* y = types[1].as(); - if (x == nullptr || y == nullptr) return false; - - const auto* param = attrs.as(); - Array y_shape; - if (param->auto_scheduler_rewritten_layout.size() == 0) { - y_shape = y->shape; - } else { - y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, - {"b", "j", "k"}); - } - - ICHECK(x->shape.size() == 3 && y_shape.size() == 3); - bool is_dyn = false; - Array oshape; - for (size_t i = 0; i < 3; ++i) { - if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { - is_dyn = true; - oshape.push_back(Any()); - } else { - if (i == 0) { - oshape.push_back(max(x->shape[i], y_shape[i])); - } else { - oshape.push_back(x->shape[i]); - } - } - } - if (!is_dyn) { - ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) || - reporter->AssertEQ(y_shape[0], 1)) - << "BatchDot: batch dimensions don't match, " - << " x shape=" << x->shape << ", y shape=" << y_shape; - ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2])) - << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape << ", y shape=" << y_shape; - } - oshape.Set(2, y_shape[1]); - - DataType out_dtype = param->out_dtype; - if (out_dtype.bits() == 0) { - out_dtype = x->dtype; - } - // assign output type - reporter->Assign(types[2], TensorType(oshape, out_dtype)); - return true; -} - // Positional relay function to create batch_matmul operator used by frontend FFI. Expr MakeBatchMatmul(Expr x, Expr y, DataType out_dtype) { auto attrs = make_object(); @@ -1013,7 +962,7 @@ are data in batch. 
.add_argument("x", "3D Tensor", "First input.") .add_argument("y", "3D Tensor", "Second input.") .set_support_level(10) - .add_type_rel("BatchMatmul", BatchMatmulRel); + .add_type_rel("BatchMatmul", BatchMatmulRel); // relay.nn.cross_entropy bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 29f200c67c59..cf2ec84d1a6e 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -24,10 +24,12 @@ #ifndef TVM_RELAY_OP_NN_NN_H_ #define TVM_RELAY_OP_NN_NN_H_ +#include #include #include #include +#include #include #include "../op_common.h" @@ -137,6 +139,58 @@ bool DensePackRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } +template +bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; + + const AttrType* param = attrs.as(); + Array y_shape; + if (param->auto_scheduler_rewritten_layout.size() == 0) { + y_shape = y->shape; + } else { + y_shape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, + {"b", "j", "k"}); + } + + ICHECK(x->shape.size() == 3 && y_shape.size() == 3); + bool is_dyn = false; + Array oshape; + for (size_t i = 0; i < 3; ++i) { + if (x->shape[i].as() != nullptr || y_shape[i].as() != nullptr) { + is_dyn = true; + oshape.push_back(Any()); + } else { + if (i == 0) { + oshape.push_back(max(x->shape[i], y_shape[i])); + } else { + oshape.push_back(x->shape[i]); + } + } + } + if (!is_dyn) { + ICHECK(reporter->AssertEQ(x->shape[0], y_shape[0]) || reporter->AssertEQ(x->shape[0], 1) || + reporter->AssertEQ(y_shape[0], 1)) + << "BatchDot: batch dimensions don't match, " + << " x shape=" << x->shape << ", y shape=" << y_shape; + ICHECK(reporter->AssertEQ(x->shape[2], y_shape[2])) + << "BatchDot: shapes of x and y is inconsistent, " + << " x shape=" << x->shape << ", y shape=" << y_shape; + } + oshape.Set(2, y_shape[1]); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = x->dtype; + } + // assign output type + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_NN_H_ diff --git a/src/relay/qnn/op/batch_matmul.cc b/src/relay/qnn/op/batch_matmul.cc new file mode 100644 index 000000000000..bb2b73141afc --- /dev/null +++ b/src/relay/qnn/op/batch_matmul.cc @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/batch_matmul.cc + * \brief Property def of qnn batch_matmul operator. 
+ */ + +#include +#include +#include +#include + +#include "../../op/nn/nn.h" +#include "../../transforms/pattern_utils.h" +#include "../utils.h" + +namespace tvm { +namespace relay { +namespace qnn { + +// relay.op.qnn.batch_matmul + +bool QnnBatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // Expected Types: x, y, x_zero_point, y_zero_point, x_scale, y_scale, + // out_type + ICHECK_EQ(types.size(), 7); + const auto* x = types[0].as(); + const auto* y = types[1].as(); + if (x == nullptr || y == nullptr) return false; + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "BatchMatmulAttrs cannot be nullptr."; + ICHECK(x->dtype == DataType::Int(8) || x->dtype == DataType::UInt(8)) + << "Expected quantized batch_matmul type(int8, uint8) for input but was " << x->dtype; + ICHECK(y->dtype == DataType::Int(8) || y->dtype == DataType::UInt(8)) + << "Expected quantized batch_matmul type(int8, uint8) for weight but was " << y->dtype; + ICHECK(param->out_dtype == DataType::Int(32)) + << "Expected quantized batch_matmul type(int32) for output but was " << param->out_dtype; + + // Check the types of scale and zero points. + for (size_t i = 2; i < 5; ++i) { + if (types[i].as()) { + return false; + } + } + ICHECK(IsScalarType(types[2], DataType::Int(32))); // x_zero_point + ICHECK(IsScalarType(types[3], DataType::Int(32))); // y_zero_point + ICHECK(IsScalarType(types[4], DataType::Float(32))); // x_scale + ICHECK(IsScalarType(types[5], DataType::Float(32))); // y_scale + + ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; + + // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay + // BatchMatmul infer type function. + Array tensor_types = {types[0], types[1], types[6]}; + return BatchMatmulRel(tensor_types, 3, attrs, reporter); +} + +// Positional relay function to create quantized batch_matmul operator used by frontend FFI. 
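+// On the Python side this is reached through
+// relay.qnn.op.batch_matmul(x, y, x_zero_point, y_zero_point, x_scale, y_scale, out_dtype),
+// which forwards to the "relay.qnn.op._make.batch_matmul" packed function
+// registered at the bottom of this file.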
+Expr MakeQuantizedBatchMatmul(Expr x, Expr y, Expr x_zero_point, Expr y_zero_point, Expr x_scale,
+                              Expr y_scale, DataType out_dtype) {
+  auto attrs = make_object<BatchMatmulAttrs>();
+  attrs->out_dtype = out_dtype;
+  static const Op& op = Op::Get("qnn.batch_matmul");
+  return Call(op, {x, y, x_zero_point, y_zero_point, x_scale, y_scale}, Attrs(attrs), {});
+}
+
+Expr BatchMatmulFirstTerm(const Expr& quantized_x, const Expr& quantized_y,
+                          const BatchMatmulAttrs* attrs) {
+  return MakeBatchMatmul(quantized_x, quantized_y, attrs->out_dtype);
+}
+
+Expr BatchMatmulSecondTerm(const Expr& x_quantized_data, const Expr& y_zero_point) {
+  Array<Integer> axes = {2};
+  return Multiply(y_zero_point, Sum(Cast(x_quantized_data, DataType::Int(32)), axes, true, false));
+}
+
+Expr BatchMatmulThirdTerm(const Expr& y_quantized_data, const Expr& x_zero_point,
+                          int broadcast_dim_size) {
+  Array<Integer> axes = {2};
+  auto reducemult =
+      Multiply(x_zero_point, Sum(Cast(y_quantized_data, DataType::Int(32)), axes, true, false));
+  Array<Integer> newshape;
+  newshape = {1, 1, broadcast_dim_size};
+  return Reshape(reducemult, newshape);
+}
+
+Expr BatchMatmulFourthTerm(int x_zero_point_int, int y_zero_point_int, int reduction_dim_size) {
+  int32_t scalar_term = x_zero_point_int * y_zero_point_int * reduction_dim_size;
+  return MakeConstantScalar(DataType::Int(32), scalar_term);
+}
+
+Expr BatchMatmulCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
+                             const Expr& term4) {
+  auto data1_term = Subtract(term1, term2);
+  auto data2_term = Subtract(term4, term3);
+  return Add(data1_term, data2_term);
+}
+
+/*
+ * \brief Forward rewrite the qnn batch_matmul op.
+ * \param attrs The QNN batch_matmul attrs.
+ * \param new_args The new mutated args to the call node.
+ * \param arg_types The types of input and output.
+ * \return The sequence of Relay ops for qnn batch_matmul op.
+ * \note Lowering of the qnn.batch_matmul operator
+ *       A quantized tensor is represented in the following manner
+ *          A = scale_a x (QA - zp_A)
+ *       where QA is the quantized tensor, and scale_a and zp_A are the
+ *       quantization params.
+ *
+ *       Quantized batch_matmul multiplies two quantized tensors and returns a
+ *       quantized tensor of default dtype int32, with a scale equal to the
+ *       product of the scales of the input tensors, and a zero point of zero.
+ *
+ *       The lowering for asymmetric quantized batch_matmul looks similar to
+ *       quantized conv2d and dense and was originally discussed here:
+ *       https://discuss.tvm.apache.org/t/tf-lite-quantized-conv2d-operator-conversion/2651/7
+ *
+ *       The computation gets unrolled into the following 4 terms
+ *          C(m, n) = Sigma(k) (X(m, k) * Y(n, k))
+ *
+ *          The RHS becomes
+ *            Sigma(k) ([QX(m, k) - zp_x] * [QY(n, k) - zp_y])
+ *
+ *          Unrolling leads to the following sequence
+ *            Sigma(k) QX(m, k) * QY(n, k)       // Term1
+ *          - Sigma(k) zp_y * QX(m, k)           // Term2
+ *          - Sigma(k) zp_x * QY(n, k)           // Term3
+ *          + Sigma(k) zp_x * zp_y               // Term4
+ *
+ *       Term4 is a compile-time constant; how many of the remaining terms are
+ *       needed depends on whether the input zero points are zero.
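+ *
+ *       As an illustration (the numbers are made up, not taken from the tests):
+ *       with K = 2, QX(m, :) = [2, 3], QY(n, :) = [4, 5], zp_x = 1, zp_y = 2,
+ *       the direct computation is (2-1)*(4-2) + (3-1)*(5-2) = 2 + 6 = 8, while
+ *       the four terms give Term1 = 2*4 + 3*5 = 23, Term2 = 2*(2+3) = 10,
+ *       Term3 = 1*(4+5) = 9, Term4 = 2*1*2 = 4, and 23 - 10 - 9 + 4 = 8.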
+ */ +Expr QnnBatchMatmulCanonicalize(const Attrs& attrs, const Array& new_args, + const Array& arg_types) { + ICHECK_EQ(new_args.size(), 6); + Expr quantized_x = new_args[0]; + Expr quantized_y = new_args[1]; + Expr x_zero_point = new_args[2]; + Expr y_zero_point = new_args[3]; + + const auto in_shape = get_shape(arg_types[0]); + const int reduction_dim_size = get_const_int(in_shape[2]); + + const auto y_shape = get_shape(arg_types[1]); + const int broadcast_dim_size = get_const_int(y_shape[1]); + + const auto* qnn_batch_matmul_attrs = attrs.as(); + + // Extract the integer zero points. + auto y_zero_point_int = GetScalarFromConstant(y_zero_point); + auto x_zero_point_int = GetScalarFromConstant(x_zero_point); + + // Get all the terms as described in the comments. + auto term1 = BatchMatmulFirstTerm(quantized_x, quantized_y, qnn_batch_matmul_attrs); + auto term2 = BatchMatmulSecondTerm(quantized_x, y_zero_point); + auto term3 = BatchMatmulThirdTerm(quantized_y, x_zero_point, broadcast_dim_size); + auto term4 = BatchMatmulFourthTerm(x_zero_point_int, y_zero_point_int, reduction_dim_size); + + // Combine those 4 terms depending on the zero points to get the best lowering. + if (x_zero_point_int == 0 && y_zero_point_int == 0) { + // term 2, 3 and 4 become zero. + return term1; + } else if (x_zero_point_int == 0 && y_zero_point_int != 0) { + // term 3 and term 4 become zero. + return Subtract(term1, term2); + } else if (x_zero_point_int != 0 && y_zero_point_int == 0) { + // term 2 and term 4 become zero. + return Subtract(term1, term3); + } else { + return BatchMatmulCombineTerms(term1, term2, term3, term4); + } +} + +RELAY_REGISTER_OP("qnn.batch_matmul") + .describe(R"code(Applies a linear transformation: :math:`Z = XY`. +- **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` +- **weight**: quantized(int8, unit8) `(units, input_dim)` +- **out**: quantized(int32) `(x1, x2, ..., xn, units)`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(6) + .add_argument("x", "quantized 2D Tensor", "First input data.") + .add_argument("y", "quantized 2D Tensor", "Second input data.") + .add_argument("x_scale", "Tensor", "The quantization scale of the x input tensor.") + .add_argument("x_zero_point", "Tensor", "The quantization zero_point of the x input tensor.") + .add_argument("y_scale", "Tensor", "The quantization scale of the y input tensor.") + .add_argument("y_zero_point", "Tensor", "The quantization zero_point of the y input tensor.") + .set_support_level(11) + .add_type_rel("QBatchMatmul", QnnBatchMatmulRel) + .set_attr("TNonComputational", true) + .set_attr("FTVMQnnCanonicalize", QnnBatchMatmulCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.batch_matmul").set_body_typed(MakeQuantizedBatchMatmul); + +} // namespace qnn +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_qnn_batch_matmul.py b/tests/python/relay/test_op_qnn_batch_matmul.py new file mode 100644 index 000000000000..91648aca3dbc --- /dev/null +++ b/tests/python/relay/test_op_qnn_batch_matmul.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.contrib import graph_executor +from tvm.relay.testing.temp_op_attr import TempOpAttr + +# We use llvm target for testing functionality. `llvm` points to an older Intel +# generation machine, that legalizes to a simple lowering. Therefore, the +# legalization is overwritten such that it can be skipped and we use the +# QNNCanonicalizeOps lowering for the testing. +def legalize_qnn_batch_matmul(attrs, inputs, types): + return None + + +def make_requantize_params(input_scale, output_scale, output_zero_point, out_dtype): + config = { + "input_scale": input_scale, + "output_scale": output_scale, + "output_zero_point": output_zero_point, + "out_dtype": out_dtype, + } + return config + + +def make_configuration( + quantized_x, + quantized_y, + dtype, + x_shape, + y_shape, + x_zero_point, + y_zero_point, + x_scale, + y_scale, + output, + out_dtype="int32", + requantize=None, +): + config = { + "quantized_x": quantized_x, + "quantized_y": quantized_y, + "dtype": dtype, + "x_shape": x_shape, + "y_shape": y_shape, + "x_zero_point": x_zero_point, + "y_zero_point": y_zero_point, + "x_scale": x_scale, + "y_scale": y_scale, + "output": output, + "out_dtype": out_dtype, + "requantize": requantize, + } + return config + + +def make_int_configuration( + xzero_point_zero=True, yzero_point_zero=True, requantize_output=False, per_channel=False +): + x_shape, y_shape, output_shape = (1, 4, 5), (1, 3, 5), (1, 4, 3) + if xzero_point_zero == True: + x_zero_point = 0 + else: + x_zero_point = -123 + + if yzero_point_zero == True: + y_zero_point = 0 + else: + y_zero_point = -123 + + in_dtype = "int8" + out_dtype = "int32" if not requantize_output else "int8" + quantized_x_np = ( + np.array( + [ + 1, + 3, + 5, + 7, + 9, # sum = 25 + 11, + 13, + 15, + -19, + -21, # sum = -1 + 1, + 3, + 5, + 7, + 9, # sum = 25 + 11, + 13, + -17, + 17, + -21, + ] + ) # sum = 3 + .astype(in_dtype) + .reshape(x_shape) + ) + quantized_y_np = ( + np.array([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 1, 3, 5, 7, 9]) + .astype(in_dtype) + .reshape(y_shape) + ) + x_scale = 0.5 + y_scale = 0.5 + output_scale = 2.0 + + if requantize_output: + assert xzero_point_zero is True + assert yzero_point_zero is True + output = np.array([20, 51, 20, -26, -27, -26, 20, 51, 20, -14, -10, -14]) + elif xzero_point_zero is False and yzero_point_zero is False: + output = np.array( + [81960, 88360, 81960, 78400, 84540, 78400, 81960, 88360, 81960, 78984, 85164, 78984] + ) + elif xzero_point_zero is True and yzero_point_zero is False: + output = np.array([3240, 3490, 3240, -320, -330, -320, 3240, 3490, 3240, 264, 294, 264]) + elif xzero_point_zero is False and yzero_point_zero is True: + output = np.array([3240, 9640, 3240, 2878, 9018, 2878, 3240, 9640, 3240, 2970, 9150, 2970]) + else: + output = np.array([165, 415, 165, -197, -207, -197, 165, 415, 165, -105, -75, -105]) + + requant_params = ( + make_requantize_params(x_scale * y_scale, output_scale, -1, "int8") + if requantize_output + else None + ) + + output = output.astype(out_dtype).reshape(output_shape) + return make_configuration( + 
quantized_x=quantized_x_np, + quantized_y=quantized_y_np, + dtype=in_dtype, + x_shape=x_shape, + y_shape=y_shape, + x_zero_point=x_zero_point, + y_zero_point=y_zero_point, + x_scale=x_scale, + y_scale=y_scale, + output=output, + requantize=requant_params, + ) + + +def qnn_batch_matmul_driver(test_configuration): + in_dtype = test_configuration["dtype"] + out_dtype = test_configuration["out_dtype"] + quantized_x_name = "quantized_x" + quantized_y_name = "quantized_y" + expected_out_dtype = test_configuration["out_dtype"] + quantized_x = relay.var(quantized_x_name, shape=test_configuration["x_shape"], dtype=in_dtype) + quantized_y = relay.var(quantized_y_name, shape=test_configuration["y_shape"], dtype=in_dtype) + mod = relay.qnn.op.batch_matmul( + quantized_x, + quantized_y, + relay.const(test_configuration["x_zero_point"], "int32"), + relay.const(test_configuration["y_zero_point"], "int32"), + relay.const(test_configuration["x_scale"], "float32"), + relay.const(test_configuration["y_scale"], "float32"), + ) + if test_configuration["requantize"] is not None: + requantize_config = test_configuration["requantize"] + mod = relay.qnn.op.requantize( + mod, + input_scale=relay.const(requantize_config["input_scale"], "float32"), + input_zero_point=relay.const(0, "int32"), + output_scale=relay.const(requantize_config["output_scale"], "float32"), + output_zero_point=relay.const(requantize_config["output_zero_point"], "int32"), + out_dtype=requantize_config["out_dtype"], + ) + expected_out_dtype = requantize_config["out_dtype"] + + mod = relay.Function(relay.analysis.free_vars(mod), mod) + mod = tvm.IRModule.from_expr(mod) + mod = relay.transform.InferType()(mod) + mod = relay.qnn.transform.CanonicalizeOps()(mod) + with tvm.transform.PassContext(opt_level=2): + graph, lib, params = relay.build(mod, "llvm", params=None) + mod = graph_executor.create(graph, lib, device=tvm.cpu(0)) + mod.set_input(quantized_x_name, test_configuration[quantized_x_name]) + mod.set_input(quantized_y_name, test_configuration[quantized_y_name]) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).numpy() + np.testing.assert_equal(res, test_configuration["output"]) + assert res.dtype == expected_out_dtype + + +def test_qnn_batch_matmul_xzp0_yzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=True) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_xzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=True, yzero_point_zero=False) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_yzp0(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=True) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul(): + with TempOpAttr("qnn.batch_matmul", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int32_output_params = make_int_configuration(xzero_point_zero=False, yzero_point_zero=False) + qnn_batch_matmul_driver(int32_output_params) + + +def test_qnn_batch_matmul_with_requantized_output(): + with TempOpAttr("qnn.dense", "FTVMQnnLegalize", legalize_qnn_batch_matmul): + + int8_requantized_output_params = make_int_configuration(requantize_output=True) + 
qnn_batch_matmul_driver(int8_requantized_output_params) + + +if __name__ == "__main__": + test_qnn_batch_matmul_xzp0_yzp0() + test_qnn_batch_matmul_xzp0() + test_qnn_batch_matmul_yzp0() + test_qnn_batch_matmul() + test_qnn_batch_matmul_with_requantized_output()
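
For reference, a minimal sketch of exercising the new operator outside the test harness; the shapes and quantization parameters below are illustrative and mirror the values used in the tests above:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 4, 5), dtype="int8")
    y = relay.var("y", shape=(1, 3, 5), dtype="int8")
    out = relay.qnn.op.batch_matmul(
        x,
        y,
        relay.const(-123, "int32"),  # x_zero_point
        relay.const(-123, "int32"),  # y_zero_point
        relay.const(0.5, "float32"),  # x_scale
        relay.const(0.5, "float32"),  # y_scale
    )
    mod = tvm.IRModule.from_expr(relay.Function([x, y], out))
    mod = relay.transform.InferType()(mod)
    # Rewrite qnn.batch_matmul into the four-term expansion implemented by
    # QnnBatchMatmulCanonicalize in src/relay/qnn/op/batch_matmul.cc.
    mod = relay.qnn.transform.CanonicalizeOps()(mod)
    print(mod)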