Skip to content

Commit

Permalink
Add batch_matmul conversion to FQ2I pass
Browse files Browse the repository at this point in the history
  • Loading branch information
elvin-n committed Aug 3, 2021
1 parent 09e234d commit 5ff9a95
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ def dense(expr, type_map):
return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]


@register_fake_quantization_to_integer("nn.batch_matmul")
def batch_matmul(expr, type_map):
    """Rewrite a fake-quantized batch_matmul into qnn.op.batch_matmul.

    Looks up the affine (scale / zero-point) types of both operands in
    *type_map*, emits the integer qnn batch_matmul, and returns the new
    call together with its output affine type: scale is the product of
    the input scales and the zero point is 0, per the qnn contract.
    """
    lhs, rhs = expr.args
    lhs_affine = type_map[lhs]
    rhs_affine = type_map[rhs]
    # The output scale of an integer matmul is the product of the input scales.
    out_scale = fold_constant(lhs_affine.scale * rhs_affine.scale)
    out_zp = relay.const(0)
    qnn_out = relay.qnn.op.batch_matmul(
        lhs,
        rhs,
        lhs_affine.zero_point,
        rhs_affine.zero_point,
        lhs_affine.scale,
        rhs_affine.scale,
    )
    return [qnn_out, TensorAffineType(out_scale, out_zp, qnn_out.attrs.out_dtype)]


@register_fake_quantization_to_integer("concatenate")
def concat(expr, type_map):
"""Rewrite a concat op"""
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/test_pass_fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,24 @@ def test_fake_quantize_dense():
compare_fq_to_int(op, [x_np, w_np])


def test_fake_quantize_batch_matmul():
    """Check FQ2I rewrites dequantize -> batch_matmul -> quantize for both output dtypes."""
    x_shape = [1, 128, 64]
    w_shape = [1, 256, 64]
    zero = relay.const(0)
    for out_dtype in ("int8", "uint8"):
        x = relay.var("x", shape=x_shape, dtype="int8")
        w = relay.var("w", shape=w_shape, dtype="int8")
        dq_x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
        dq_w = relay.qnn.op.dequantize(w, relay.const(0.5), zero)
        op = relay.qnn.op.quantize(
            relay.op.nn.batch_matmul(dq_x, dq_w), relay.const(1.0), zero, out_dtype=out_dtype
        )
        x_np = np.random.randint(-128, 127, size=x_shape, dtype="int8")
        w_np = np.random.randint(-128, 127, size=w_shape, dtype="int8")
        # The fake-quantized graph and its integer rewrite must agree.
        compare_fq_to_int(op, [x_np, w_np])


def test_fake_transpose_quantize_conv():
x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8")
w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
Expand Down

0 comments on commit 5ff9a95

Please sign in to comment.