diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 783204fb700f0..05266052c41c6 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -139,6 +139,20 @@ def dense(expr, type_map):
     return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]
 
 
+@register_fake_quantization_to_integer("nn.batch_matmul")
+def batch_matmul(expr, type_map):
+    """Rewrite a batch_matmul op"""
+    x, y = expr.args
+    x_t = type_map[x]
+    y_t = type_map[y]
+    matmul_scale = fold_constant(x_t.scale * y_t.scale)
+    matmul_zp = relay.const(0)
+    out = relay.qnn.op.batch_matmul(
+        x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale
+    )
+    return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)]
+
+
 @register_fake_quantization_to_integer("concatenate")
 def concat(expr, type_map):
     """Rewrite a concat op"""
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index 1e7d749ff418a..cee490b1ab43b 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -79,6 +79,24 @@ def test_fake_quantize_dense():
     compare_fq_to_int(op, [x_np, w_np])
 
 
+def test_fake_quantize_batch_matmul():
+    for out_dtype in ["int8", "uint8"]:
+        x = relay.var("x", shape=[1, 128, 64], dtype="int8")
+        w = relay.var("w", shape=[1, 256, 64], dtype="int8")
+        one = relay.const(1.0)
+        zero = relay.const(0)
+
+        op = relay.op.nn.batch_matmul(
+            relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+            relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+        )
+        op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype)
+
+        x_np = np.random.randint(-128, 127, size=[1, 128, 64], dtype="int8")
+        w_np = np.random.randint(-128, 127, size=[1, 256, 64], dtype="int8")
+
+        compare_fq_to_int(op, [x_np, w_np])
+
 def test_fake_transpose_quantize_conv():
     x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8")
     w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
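
A minimal standalone sketch of what the new rule does, for readers who want to try it outside the test harness: it builds the dequantize -> batch_matmul -> quantize pattern that the rewrite targets and runs the FakeQuantizationToInteger pass on it directly. This is illustrative only, not part of the patch; shapes and scale values are arbitrary examples, and only public relay/qnn APIs already used in the test above are assumed.

# Illustrative sketch (not part of the patch): manually running the
# FakeQuantizationToInteger pass on the fake-quantized batch_matmul pattern.
import tvm
from tvm import relay

x = relay.var("x", shape=[1, 128, 64], dtype="int8")
w = relay.var("w", shape=[1, 256, 64], dtype="int8")
zero = relay.const(0)

# The "fake quantized" float graph: int8 inputs are dequantized, multiplied
# in float, then quantized back to int8.
op = relay.op.nn.batch_matmul(
    relay.qnn.op.dequantize(x, relay.const(2.0), zero),
    relay.qnn.op.dequantize(w, relay.const(0.5), zero),
)
op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")

mod = tvm.IRModule.from_expr(relay.Function([x, w], op))
mod = relay.transform.InferType()(mod)

# After the pass, the printed module should contain qnn.batch_matmul on the
# int8 inputs followed by a requantize, with the dequantize/quantize pair
# folded away.
print(relay.transform.FakeQuantizationToInteger()(mod))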