[ONNX] QGemm support (apache#13747)
Co-authored-by: cheng.wen <wen.cheng@intellif.com>
2 people authored and fzi-peccia committed Mar 27, 2023
1 parent 3d42755 commit 427b548
Showing 2 changed files with 253 additions and 0 deletions.
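
For context, the com.microsoft.QGemm contrib op computes Y = alpha * A' @ B' (+ C), where A' and B' are the optionally transposed uint8 inputs and each input dequantizes as real = scale * (quantized - zero_point). A minimal numpy sketch of that reference semantics, assuming per-tensor scales (the helper name is illustrative, not part of this patch):

import numpy as np

# Float reference for com.microsoft.QGemm, assuming per-tensor scales:
#   Y = alpha * A' @ B' (+ C), with A' = A.T if transA else A, likewise for B.
# uint8 inputs dequantize as real = scale * (quantized - zero_point); C, when
# present, is int32 with scale alpha * a_scale * b_scale and zero point 0.
def qgemm_reference(a, a_scale, a_zp, b, b_scale, b_zp,
                    c=None, alpha=1.0, transA=0, transB=0):
    a_deq = a_scale * (a.astype(np.int32) - a_zp)
    b_deq = b_scale * (b.astype(np.int32) - b_zp)
    if transA:
        a_deq = a_deq.T
    if transB:
        b_deq = b_deq.T
    y = alpha * (a_deq @ b_deq)
    if c is not None:
        y = y + alpha * a_scale * b_scale * c.astype(np.float32)
    return y  # float output; requantize with y_scale / y_zero_point when given
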
85 changes: 85 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -5091,6 +5091,90 @@ def _impl_v10(cls, inputs, attr, params):
        return out


class QGemm(OnnxOpConverter):
    """Operator converter for QGemm."""

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QGemm

        a = inputs[0]
        a_scale = get_scalar(inputs[1], params)
        a_zp = get_scalar(inputs[2], params, "int32")

        b = inputs[3]
        # must be a scalar or a 1-D tensor, i.e. per-tensor or per-column quantization
        # if 1-D, the number of elements must equal the number of columns of input B
        b_scale = get_scalar_or_1d_tensor(inputs[4], params)
        b_zp = get_scalar_or_1d_tensor(inputs[5], params, "int32")

        # optional inputs that were not provided arrive here as None
        C = inputs[6]
        # must be null, a scalar, or a 1-D tensor of size 1
        y_scale = inputs[7]
        # must be null, a scalar, or a 1-D tensor of size 1
        y_zp = get_scalar(inputs[8], params, "int32")

        assert len(infer_shape(a)) == 2
        assert len(infer_shape(b)) == 2
        # the scale and zero point of input B must have the same shape
        assert infer_shape(b_scale) == infer_shape(b_zp)

        alpha = float(attr.get("alpha", 1.0))
        transA = int(attr.get("transA", 0))
        transB = int(attr.get("transB", 0))

        # get number of channels
        channels = infer_channels(b, not transB)
        a_dtype = infer_type(a).checked_type.dtype

        # qnn.dense expects data of shape [M, K] and weight of shape [N, K]
        if transA:
            a = _op.transpose(a, axes=(1, 0))
        if not transB:
            b = _op.transpose(b, axes=(1, 0))

        # qnn.dense subtracts the input zero points, so the int32 accumulator
        # carries an effective scale of a_scale * b_scale and a zero point of 0
        result = _qnn.op.dense(
            a,
            b,
            a_zp,
            b_zp,
            a_scale,
            b_scale,
            channels,
        )

        if C:
            result = _op.add(result, C)

        requantize_scale = _op.multiply(a_scale, b_scale)
        if alpha != 1.0:
            requantize_scale *= _expr.const(alpha, dtype="float32")
        requantize_zp = _op.const(0, dtype="int32")

        if y_scale:
            # requantize requires y_scale to be constant;
            # if it is not, fall back to dequantize -> quantize
            if isinstance(y_scale, _expr.Constant):
                y = _qnn.op.requantize(
                    result,
                    requantize_scale,
                    requantize_zp,
                    y_scale,
                    y_zp,
                    axis=-1,
                    rounding="TONEAREST",
                    out_dtype=a_dtype,
                )
            else:
                result_deq = _qnn.op.dequantize(result, requantize_scale, requantize_zp, axis=0)

                y = _qnn.op.quantize(result_deq, y_scale, y_zp, axis=0, out_dtype=a_dtype)
        else:
            # no y_scale: return float output scaled by the combined scale
            y = _op.multiply(_op.cast(result, "float32"), requantize_scale)

        return y
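
The requantize path above leans on a simple identity: qnn.dense produces the int32 accumulator (q_a - z_a) @ (q_b - z_b), so multiplying it by alpha * a_scale * b_scale reproduces the float product exactly. A quick numpy check of that identity, with illustrative values:

import numpy as np

# For A = s_a * (q_a - z_a) and B = s_b * (q_b - z_b):
#   alpha * A @ B == (alpha * s_a * s_b) * ((q_a - z_a) @ (q_b - z_b))
# so the int32 accumulator has scale s_a * s_b (times alpha) and zero point 0.
s_a, z_a, s_b, z_b, alpha = 0.02, 128, 0.05, 100, 0.5
q_a = np.random.randint(0, 256, (4, 8)).astype(np.int64)
q_b = np.random.randint(0, 256, (8, 3)).astype(np.int64)
acc = (q_a - z_a) @ (q_b - z_b)  # what qnn.dense computes, before requantize
lhs = alpha * ((s_a * (q_a - z_a)) @ (s_b * (q_b - z_b)))
assert np.allclose(lhs, alpha * s_a * s_b * acc)
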


class QLinearAdd(OnnxOpConverter):
"""Operator converter for QLinearAdd from Microsoft onnxruntime contrib opset."""

@@ -6337,6 +6421,7 @@ def _get_convert_map(opset):
"DequantizeLinear": DequantizeLinear.get_converter(opset),
"DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
"ReverseSequence": ReverseSequence.get_converter(opset),
"QGemm": QGemm.get_converter(opset),
"QLinearConv": QLinearConv.get_converter(opset),
"QLinearConcat": QLinearConcat.get_converter(opset),
"QLinearAdd": QLinearAdd.get_converter(opset),
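
With the entry above registered, an ONNX model carrying a com.microsoft QGemm node imports through the regular frontend path. A hedged usage sketch — the file name and input shapes are placeholders, assuming a model shaped like the one the new test below builds:

import onnx
from tvm import relay

# "qgemm_model.onnx" is an illustrative path; any model containing a
# com.microsoft.QGemm node will now dispatch to the QGemm converter.
model = onnx.load("qgemm_model.onnx")
mod, params = relay.frontend.from_onnx(model, shape={"a": (20, 30), "b": (50, 30)})
print(mod)
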
168 changes: 168 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -6213,6 +6213,174 @@ def verify_skiplayernormalization(input_, skip, gamma, beta, bias):
    verify_skiplayernormalization(input_array, skip, gamma, beta, bias)


@tvm.testing.known_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_qgemm(target, dev):
"""test_qgemm"""

def verify_qgemm(
a_shape,
b_shape,
y_shape,
C=False,
y_zp=False,
b_per_tensor_quantization=False,
alpha=1.0,
transA=0,
transB=1,
):
a_array = np.random.randint(low=0, high=255, size=a_shape).astype("uint8")
b_array = np.random.uniform(low=0, high=255, size=b_shape).astype("uint8")

input_nodes = [
helper.make_tensor_value_info("a", TensorProto.UINT8, list(a_shape)),
helper.make_tensor_value_info("b", TensorProto.UINT8, list(b_shape)),
]

initializer = [
helper.make_tensor("a_scale", TensorProto.FLOAT, (), [np.random.rand()]),
helper.make_tensor("a_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]),
]

input_names = [
"a",
"a_scale",
"a_zero_point",
"b",
"b_scale",
"b_zero_point",
]
input_values = [a_array, b_array]

if b_per_tensor_quantization:
initializer.append(
helper.make_tensor("b_scale", TensorProto.FLOAT, (), [np.random.rand()])
)
initializer.append(
helper.make_tensor(
"b_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]
)
)
else: # per_colume_quantization
shape_value = b_shape[0] if transB else b_shape[1]
b_scale_array = np.random.random(shape_value).astype("float32")
w_zero_point_array = np.random.randint(0, 255, size=shape_value).astype("uint8")
initializer.append(
helper.make_tensor(
"b_scale", TensorProto.FLOAT, list(b_scale_array.shape), b_scale_array
)
)
initializer.append(
helper.make_tensor(
"b_zero_point",
TensorProto.UINT8,
list(w_zero_point_array.shape),
w_zero_point_array,
)
)

output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, list(y_shape))

if C is True:
C_shape = (b_shape[0] if transB else b_shape[1],)
C_array = np.random.randint(low=0, high=65536, size=C_shape).astype("int32")
input_nodes.append(helper.make_tensor_value_info("C", TensorProto.INT32, list(C_shape)))
input_names.append("C")
input_values.append(C_array)

if y_zp is True:
input_names.append("y_scale")
initializer.append(
helper.make_tensor("y_scale", TensorProto.FLOAT, (), [np.random.rand()])
)

input_names.append("y_zero_point")
initializer.append(
helper.make_tensor(
"y_zero_point", TensorProto.UINT8, (), [np.random.randint(0, 255)]
)
)

output_tensor = helper.make_tensor_value_info(
"output", TensorProto.UINT8, list(y_shape)
)

kwargs = {}
kwargs["alpha"] = alpha
kwargs["transA"] = transA
kwargs["transB"] = transB

node = helper.make_node(
"QGemm",
inputs=input_names,
outputs=["output"],
domain="com.microsoft",
# Default values for other attributes:
**kwargs,
)

graph = helper.make_graph(
[node],
"QGemm",
inputs=input_nodes,
outputs=[output_tensor],
initializer=initializer,
)
model = helper.make_model(
graph,
producer_name="QGemm",
opset_imports=[
onnx.helper.make_opsetid("com.microsoft", 1),
],
)

verify_with_ort_with_inputs(model, input_values, target=target, dev=dev)

# B per tensor quantization
verify_qgemm(
(20, 30),
(50, 30),
(20, 50),
True,
True,
True,
)

# B per column quantization
verify_qgemm(
(20, 30),
(50, 30),
(20, 50),
True,
True,
False,
)

# test alpha
verify_qgemm(
(20, 30),
(50, 30),
(20, 50),
True,
True,
True,
0.5,
)

# test transpose A
verify_qgemm(
(20, 50),
(20, 80),
(50, 80),
True,
True,
True,
0.5,
1,
0,
)
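
In the per-column branch above, b_scale and b_zero_point carry one entry per output column N, i.e. b_shape[0] when transB=1. A small sketch of what that per-column dequantization means for B stored as [N, K] (names illustrative):

import numpy as np

# B stored as [N, K] (transB = 1): each of the N output columns has its own
# scale / zero point, broadcast across the K (reduction) axis.
N, K = 50, 30
q_b = np.random.randint(0, 256, (N, K))
b_scale = np.random.random(N).astype("float32")
b_zp = np.random.randint(0, 256, N)
b_deq = b_scale[:, None] * (q_b - b_zp[:, None])  # dequantized B, shape [N, K]
assert b_deq.shape == (N, K)
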


@tvm.testing.known_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_qlinearconv(target, dev):
