Support QLinearAdd from onnx runtime com.microsoft contrib ops. (apache#8305)

* support QLinearAdd

* fix comment line length

* use platform independent temp directory
Matthew Brookhart authored and ylc committed Sep 29, 2021
1 parent f4e767c commit 9b42371
Showing 2 changed files with 104 additions and 2 deletions.
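
For orientation (not part of the diff itself): the new converter takes the route described in its own code comment below, dequantizing both quantized operands to fp32, adding them, and requantizing with the output scale and zero point. A minimal numpy sketch of that semantics, assuming uint8 tensors and made-up scales and zero points:

import numpy as np

def qlinear_add_ref(a, a_scale, a_zp, b, b_scale, b_zp, c_scale, c_zp):
    # Dequantize both uint8 operands to fp32, add, then requantize to uint8.
    a_fp = (a.astype("float32") - a_zp) * a_scale
    b_fp = (b.astype("float32") - b_zp) * b_scale
    c_fp = a_fp + b_fp
    return np.clip(np.round(c_fp / c_scale) + c_zp, 0, 255).astype("uint8")

# Illustrative values only.
a = np.array([10, 200], dtype="uint8")
b = np.array([30, 40], dtype="uint8")
print(qlinear_add_ref(a, 0.02, 0, b, 0.05, 10, 0.1, 5))  # -> [17 60]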
44 changes: 43 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -2973,6 +2973,8 @@ def _impl_v13(cls, inputs, attr, params):
         data, scale, zp = inputs
         out_dtype = infer_type(zp).checked_type.dtype
         axis = attr.get("axis", 1)
+        if len(infer_shape(data)) < 2:
+            axis = 0
         return _qnn.op.quantize(data, scale, _op.cast(zp, "int32"), axis, out_dtype)


@@ -3033,10 +3035,11 @@ def get_scalar(x, dtype="float32"):
         weight = inputs[3]
         w_scale = get_scalar(inputs[4])
         w_zero_point = get_scalar(inputs[5], "int32")
-        y_scale = get_scalar(inputs[6])
+        y_scale = fold_constant(get_scalar(inputs[6]))
         y_zero_point = get_scalar(inputs[7], "int32")

         input_shape = infer_shape(data)
+
         ndim = len(input_shape)
         kernel_type = infer_type(weight)
         kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
@@ -3116,6 +3119,44 @@ def get_scalar(x, dtype="float32"):
         return out


+class QLinearAdd(OnnxOpConverter):
+    """Operator converter for QLinearAdd from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        def get_scalar(x, dtype="float32"):
+            if isinstance(x, _expr.Var) and x.name_hint in params:
+                return _op.const(params[x.name_hint].numpy(), dtype)
+            rank = len(infer_shape(x))
+            assert rank <= 1, "QLinearConv scale and zero_point input must be scalars"
+            if rank == 1:
+                x = _op.squeeze(x, [0])
+            return _op.cast(x, dtype)
+
+        a = inputs[0]
+        a_scale = get_scalar(inputs[1])
+        a_zero_point = get_scalar(inputs[2], "int32")
+        b = inputs[3]
+        b_scale = get_scalar(inputs[4])
+        b_zero_point = get_scalar(inputs[5], "int32")
+        c_scale = get_scalar(inputs[6])
+        c_zero_point = get_scalar(inputs[7], "int32")
+
+        dtype = infer_type(a).checked_type.dtype
+
+        ## Onnxruntime doesn't actually do this op in integer, they dequantize to fp32
+        ## and then requantize after
+        ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
+        a = _qnn.op.dequantize(
+            inputs[0], a_scale, a_zero_point
+        )  # , c_scale, c_zero_point, out_dtype = dtype)
+        b = _qnn.op.dequantize(
+            inputs[3], b_scale, b_zero_point
+        )  # , c_scale, c_zero_point, out_dtype = dtype)
+        out = _op.add(a, b)
+        return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
+
+
 class BitShift(OnnxOpConverter):
     """Operator converter for NonZero"""

@@ -3343,6 +3384,7 @@ def _get_convert_map(opset):
         "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
         "ReverseSequence": ReverseSequence.get_converter(opset),
         "QLinearConv": QLinearConv.get_converter(opset),
+        "QLinearAdd": QLinearAdd.get_converter(opset),
     }
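
Not part of the commit, but for context on how the new converter gets exercised end to end: once "QLinearAdd" is in the convert map above, a quantized ONNX model containing that contrib op can be imported through the usual Relay frontend. A hedged usage sketch, with a hypothetical model path and assumed input names and shapes:

import onnx
import tvm
from tvm import relay

onnx_model = onnx.load("model.quant.onnx")  # hypothetical quantized model path
shape_dict = {"a": (4, 2), "b": (4, 2)}  # assumed input names and shapes
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
# opt_level=2 mirrors the test below; opt_level=1 is noted there as problematic for qnn lowering
with tvm.transform.PassContext(opt_level=2):
    lib = relay.build(mod, target="llvm", params=params)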


62 changes: 61 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import os
 import re

 import numpy as np
@@ -120,7 +121,7 @@ def get_tvm_output(
 def get_onnxruntime_output(model, inputs):
     import onnxruntime.backend

-    rep = onnxruntime.backend.prepare(model, "CPU")
+    rep = onnxruntime.backend.prepare(model.SerializeToString(), "CPU")
     if isinstance(inputs, list) and len(inputs) == 1:
         inp = inputs[0]
     else:
@@ -149,6 +150,7 @@ def verify_with_ort_with_inputs(
 ):
     if opset is not None:
         model.opset_import[0].version = opset
+
     ort_out = get_onnxruntime_output(model, inputs)

     if targets is None:
@@ -4755,6 +4757,64 @@ def repeat(N, D):
     )


+def verify_qlinearadd(a_shape, b_shape, c_shape):
+
+    a_array = np.random.random(a_shape).astype("float32")
+    b_array = np.random.random(b_shape).astype("float32")
+
+    input_nodes = [
+        helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
+        helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
+    ]
+    input_names = [
+        "a",
+        "b",
+    ]
+    input_values = [a_array, b_array]
+
+    node = helper.make_node("QLinearAdd", inputs=input_names, outputs=["C"])
+
+    node = helper.make_node("Add", ["a", "b"], ["C"])
+    graph = helper.make_graph(
+        [node],
+        "qlinearadd_test",
+        inputs=input_nodes,
+        outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))],
+    )
+    model = helper.make_model(graph, producer_name="qlinearconv_test")
+    from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
+
+    class RandomDataReader(CalibrationDataReader):
+        def __init__(self, n=10):
+            self.data = iter(
+                [
+                    {
+                        "a": np.random.random(a_shape).astype("float32"),
+                        "b": np.random.random(b_shape).astype("float32"),
+                    }
+                    for _ in range(n)
+                ]
+            )
+
+        def get_next(self):
+            return next(self.data, None)
+
+    d = tvm.contrib.utils.tempdir()
+    model_fp32 = os.path.join(d.temp_dir, "model.onnx")
+    onnx.save_model(model, model_fp32)
+    model_quant = os.path.join(d.temp_dir, "model.quant.onnx")
+    quantized_model = quantize_static(model_fp32, model_quant, RandomDataReader())
+    # opt_level=1 will cause error with qnn lowering
+    model = onnx.load(model_quant)
+    verify_with_ort_with_inputs(model, input_values, opt_level=2)
+
+
+def test_qlinearadd():
+    verify_qlinearadd([4, 2], [4, 2], [4, 2])
+    verify_qlinearadd([4, 2], [2], [4, 2])
+    verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7])
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
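
An informal check, not part of the test suite: to confirm that quantize_static actually emitted the contrib op this commit targets, the quantized model written by verify_qlinearadd can be loaded and its node types listed. The exact op set depends on the onnxruntime version and quantization format, so treat the expected output as an assumption:

import onnx

m = onnx.load("model.quant.onnx")  # path written by quantize_static in verify_qlinearadd above
print(sorted({n.op_type for n in m.graph.node}))
# Expected to include "QLinearAdd" alongside QuantizeLinear/DequantizeLinear nodes.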