[AMP] Disallow fp16 conversion for summation-like ops (#8810)
* [AMP] Disallow fp16 conversion for summation-like ops

* test only structural equality
masahi authored Aug 26, 2021
1 parent 3d81489 commit f4f525d
Showing 2 changed files with 29 additions and 17 deletions.
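For intuition (not part of this commit, just a NumPy sketch): summation-like ops are risky in fp16 because the running accumulator can exceed fp16's maximum finite value (~65504) or lose precision long before the final result, which is why sum, mean, and the average-pooling ops move to the never-convert list below.

import numpy as np

# ~786k values in [0, 1): the true sum is roughly 3.9e5, far past fp16's max of ~65504.
x = np.random.uniform(0, 1, size=(1, 3, 512, 512)).astype("float32")
print(x.sum())                                   # fp32 accumulation: finite and accurate
print(x.astype("float16").sum(dtype="float16"))  # fp16 accumulation: overflows to inf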
15 changes: 7 additions & 8 deletions python/tvm/relay/transform/mixed_precision.py
@@ -81,8 +81,6 @@
     "divide",
     "nn.bias_add",
     "nn.batch_norm",
-    "sum",
-    "mean",
     "sqrt",
     "shape_of",
     # Simple activations
Expand All @@ -107,15 +105,9 @@
# "nn.global_max_pool1d", # does not exist yet
"nn.global_max_pool2d",
# "nn.global_max_pool3d", # does not exist yet
# "nn.global_avg_pool1d", # does not exist yet
"nn.global_avg_pool2d",
# "nn.global_avg_pool3d", # does not exist yet
"nn.adaptive_max_pool1d",
"nn.adaptive_max_pool2d",
"nn.adaptive_max_pool3d",
"nn.adaptive_avg_pool1d",
"nn.adaptive_avg_pool2d",
"nn.adaptive_avg_pool3d",
]
DEFAULT_NEVER_LIST = [
# In general if |f(x)| >> |x| for expected inputs then put the op here.
@@ -131,6 +123,13 @@
     # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number
     # not representable in fp16.
     "arange",
+    # Ops that could involve a large summation are not allowed in fp16.
+    "nn.global_avg_pool2d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+    "sum",
+    "mean",
 ]


31 changes: 22 additions & 9 deletions tests/python/relay/test_to_mixed_precision.py
@@ -221,23 +221,36 @@ def test_do_not_convert_softmax():
     b = relay.nn.softmax(a)
     mod = tvm.IRModule.from_expr(b)
     mod = tvm.relay.transform.InferType()(mod)
-
-    mod_params = {
-        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
-    }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_do_not_convert_arange():
     """Arange is a red listed operation and therefore should never be fp16."""
     dtype = "float32"
     arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
     mod = tvm.IRModule.from_expr(arange)
     mod = tvm.relay.transform.InferType()(mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
-    output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
-
+def test_do_not_convert_summation():
+    """Ops that could involve a large summation are not allowed in fp16."""
+    shape = [1, 3, 16, 16]
+    a = relay.var("a", shape=shape)
+    ops = [
+        relay.sum,
+        relay.mean,
+        relay.nn.global_avg_pool2d,
+        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
+    ]
+    for op in ops:
+        mod = tvm.IRModule.from_expr(op(a))
+        out_mod = ToMixedPrecision("float16")(mod)
+        orig_mod = tvm.relay.transform.InferType()(mod)
+        assert tvm.ir.structural_equal(orig_mod, out_mod)
+
 
 def test_green_gray_propagates_simple():
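For context, a minimal usage sketch of the pass after this change (the conv2d/mean graph and the variable names below are illustrative, not taken from the commit):

import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
out = relay.mean(conv, axis=[2, 3])  # "mean" is now on DEFAULT_NEVER_LIST

mod = InferType()(tvm.IRModule.from_expr(out))
print(ToMixedPrecision("float16")(mod))
# Expected: conv2d is converted to fp16, while the mean keeps fp32 operands
# (the pass should cast the conv output back to fp32 before the reduction).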
