Skip to content

Commit

Permalink
fix pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew committed Jul 14, 2021
1 parent f961332 commit ba175c8
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def quantize(expr, type_map):
return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)]


def register_unary_identity(op_name, op):
def register_unary_identity(op_name):
def identity(expr, type_map):
assert len(expr.args) == 1
arg = expr.args[0]
Expand All @@ -66,13 +66,13 @@ def identity(expr, type_map):
return register_fake_quantization_to_integer(op_name, identity)


register_unary_identity("reshape", relay.op.reshape)
register_unary_identity("squeeze", relay.op.squeeze)
register_unary_identity("strided_slice", relay.op.strided_slice)
register_unary_identity("transpose", relay.op.transpose)
register_unary_identity("expand_dims", relay.op.expand_dims)
register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d)
register_unary_identity("nn.batch_flatten", relay.op.nn.batch_flatten)
register_unary_identity("reshape")
register_unary_identity("squeeze")
register_unary_identity("strided_slice")
register_unary_identity("transpose")
register_unary_identity("expand_dims")
register_unary_identity("nn.max_pool2d")
register_unary_identity("nn.batch_flatten")


@register_fake_quantization_to_integer("nn.avg_pool2d")
Expand Down Expand Up @@ -201,6 +201,7 @@ def clip(expr, type_map):

@register_fake_quantization_to_integer("nn.pad")
def pad(expr, type_map):
"""Rewite an nn.pad op"""
arg = expr.args[0]
t = type_map[arg]
pad_value = expr.args[1]
Expand All @@ -219,12 +220,12 @@ def pad(expr, type_map):
assert isinstance(pad_value, relay.expr.Constant)
pad_value = relay.qnn.op.quantize(pad_value, t.scale, t.zero_point)

z_p = fold_constant(t.zero_point)
out = relay.op.nn.pad(arg, pad_value=pad_value, **expr.attrs)
return [out, t]


def get_binary_types(expr, type_map):
"""Get Affine types of a binary op's inputs and unify them"""
##Support the case where one input is quantized and the other is a constant float
left = expr.args[0]
right = expr.args[1]
Expand Down Expand Up @@ -262,6 +263,7 @@ def get_binary_types(expr, type_map):


def register_binary_qnn(op_name, op):
"""Register a Binary Op that converts to QNN"""
def binary(expr, type_map):
left, right, left_t, right_t, out_t = get_binary_types(expr, type_map)
out = op(
Expand All @@ -280,12 +282,13 @@ def binary(expr, type_map):


# Use lambdas here to avoid a circular import problem
register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args))
register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args))
register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args))
register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args)) # pylint: disable=unnecessary-lambda
register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args)) # pylint: disable=unnecessary-lambda
register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args)) # pylint: disable=unnecessary-lambda


def register_binary_identity(op_name, op):
"""Register a binary op that works directly on int8"""
def binary(expr, type_map):
left, right, left_t, right_t, out_t = get_binary_types(expr, type_map)
if left_t != out_t:
Expand Down

0 comments on commit ba175c8

Please sign in to comment.