diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py index 8a7c4cbfbbd6..31f5ac6e71b1 100644 --- a/tests/python/relay/test_pass_auto_quantize.py +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -307,6 +307,39 @@ def @main( verify_partition_fails(mod, params) +def test_left_shift_negative(): + data = relay.var("data", shape=(1, 16, 64, 64)) + weight = relay.const(np.full((16, 16, 3, 3), 256.0)) + conv2d = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16) + relu = relay.nn.relu(conv2d) + + mod = tvm.IRModule.from_expr(relu) + + with tvm.transform.PassContext(opt_level=3): + with relay.quantize.qconfig( + calibrate_mode="global_scale", global_scale=8.0, skip_conv_layers=None + ): + qnn_mod = relay.quantize.quantize(mod) + + class OpFinder(relay.ExprVisitor): + def __init__(self, op_name): + super(OpFinder, self).__init__() + self._op_name = op_name + self.ops = list() + + def visit_call(self, call): + super().visit_call(call) + if call.op.name == self._op_name: + self.ops.append(call) + + opf = OpFinder("left_shift") + opf.visit(qnn_mod["main"]) + assert len(opf.ops) > 0, 'Broken case, can\'t find any "left_shift" operators.' + for left_shift_op in opf.ops: + shift_amount = left_shift_op.args[1].data.asnumpy() + assert shift_amount >= 0, "Shift amount must be non-negative." + + if __name__ == "__main__": test_mul_rewrite() test_batch_flatten_rewrite() @@ -320,3 +353,4 @@ def @main( test_unquantizable_prefix_partition() test_unquantizable_core_partition() test_unquantizable_suffix_partition() + test_left_shift_negative()