[AMP] Turn off accumulation data types for mixed precision pass #8341

Merged
9 commits merged on Jun 29, 2021
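In short: for ops that carry an `out_dtype` attribute (e.g. `nn.conv2d`, `nn.dense`, and now `nn.batch_matmul`), the pass accumulates in the mixed-precision dtype itself instead of `float32`, so converted modules no longer end such ops with a cast back to `float16`. A minimal sketch of exercising the pass; the import paths and toy shapes below are assumptions for illustration, not taken from this PR:

```python
import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision  # assumed import path

# Toy conv2d module; shapes are illustrative only.
data = relay.var("data", shape=[1, 3, 32, 32], dtype="float32")
weight = relay.var("weight", shape=[8, 3, 3, 3], dtype="float32")
conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
mod = InferType()(tvm.IRModule.from_expr(conv))

# With this change, the converted conv2d computes and accumulates in float16
# (out_dtype="float16"), with no trailing cast back to float16.
fp16_mod = ToMixedPrecision("float16")(mod)
print(fp16_mod)
```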
15 changes: 4 additions & 11 deletions python/tvm/relay/transform/mixed_precision.py
@@ -40,7 +40,7 @@
"nn.conv2d_transpose",
"nn.conv3d_transpose",
"nn.dense",
# "nn.batch_matmul", # Handled by a special case
"nn.batch_matmul",
]
DEFAULT_FOLLOW_LIST = [
# These ops add new data or change shape
@@ -162,7 +162,9 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) ->
# Some discussion here about making this better is here:
# https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
if hasattr(call_node.attrs, "out_dtype"):
return ["float32", mixed_precision_type]
# TODO (AndrewZhaoLuo): evaluate consistent support for mixed_type accumulators
# return ["float32", mixed_precision_type]
return [mixed_precision_type, mixed_precision_type]

# [accumulation_dtype, output_dtype] for the operations
return [mixed_precision_type, mixed_precision_type]
@@ -184,12 +186,3 @@ def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
@register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)


@register_mixed_precision_conversion("nn.batch_matmul")
def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
# TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
# Batched matmul has inconsistent support for mixed precision operations.
# Many schedules ignore the out_dtype attribute which leads to errors when
# input types do not match the out_dtype. Therefore, accumulate to output_dtype.
return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
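Since the generic handler now returns `[mixed_precision_type, mixed_precision_type]`, anyone who still wants a `float32` accumulator for a given op can override its registration, mirroring the decorator removed above. A rough sketch for `nn.dense`; the import paths and the override-via-`level` behaviour are assumptions, not verified against this PR:

```python
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion  # assumed import path
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS


def dense_accumulate_fp32(call_node: relay.Call, mixed_precision_type: str) -> List:
    # [conversion category, accumulation dtype, output dtype]
    return [MIXED_PRECISION_ALWAYS, "float32", mixed_precision_type]


# Assumes re-registration at a higher level overrides the default (level=10).
register_mixed_precision_conversion("nn.dense", dense_accumulate_fp32, level=11)
```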
9 changes: 3 additions & 6 deletions tests/python/relay/test_op_level10.py
@@ -18,14 +18,11 @@
"""
import numpy as np
import tvm
from tvm import te
import tvm.testing
import tvm.topi.testing
from tvm import relay
from tvm import relay, te, topi
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
from tvm import topi
import tvm.topi.testing
import tvm.testing


@tvm.testing.uses_gpu
@@ -608,7 +605,7 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float32"):
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, device=dev, target=target)
out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-6, atol=1e-6)

_verify((10, 5))
_verify((10, 5, 2, 2))
84 changes: 39 additions & 45 deletions tests/python/relay/test_to_mixed_precision.py
@@ -48,6 +48,7 @@ def verify_mixed_precision_output_close(
result_fp32 = run_module(mod, mod_params)
fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
result_fp16 = run_module(fp16_mod, mod_params)

# Ensure the results are close
for fp32, fp16 in zip(result_fp32, result_fp16):
np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
@@ -60,7 +61,9 @@ def test_lstm():

Has internal functions and let statements the pass must work on.
"""
units = 3
# TODO(AndrewZhaoLuo): investigate why non-even units cause failure in codegen for CUDA
# See discussion here: https://github.com/apache/tvm/issues/8294#issuecomment-866190408
units = 4
iterations = 5
mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)

@@ -118,16 +121,13 @@ def test_convert_single_conv():
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)

expected_mod = tvm.IRModule.from_expr(
relay.cast(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float32",
),
"float16",
)
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
),
)
expected_mod = tvm.relay.transform.InferType()(expected_mod)

@@ -156,16 +156,13 @@ def test_convert_single_conv_fp64():
# Note we still accumulate to FP32 by default, a user would need to overwrite default
# behavior to make this make more sense.
expected_mod = tvm.IRModule.from_expr(
relay.cast(
relay.nn.conv2d(
relay.cast(data, "float64"),
relay.cast(weight, "float64"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float32",
),
"float64",
)
relay.nn.conv2d(
relay.cast(data, "float64"),
relay.cast(weight, "float64"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float64",
),
)
expected_mod = tvm.relay.transform.InferType()(expected_mod)

@@ -198,15 +195,12 @@ def test_convert_conv_bn():
"moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
"moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
}
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.025, rtol=0.01)

# Creating expected module
data = relay.cast(relay.var("data", shape=data_shape), "float16")
weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
conv = relay.cast(
relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
"float16",
)
conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float16")

bn_shape = [5]
gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
@@ -254,17 +248,14 @@ def test_green_gray_propagates_simple():
"data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
"weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
}
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)

conv_expr = relay.cast(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float32",
),
"float16",
conv_expr = relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
)
expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
expected_mod = tvm.relay.transform.InferType()(expected_mod)
@@ -316,12 +307,15 @@ def test_green_red_not_use_extraneous_cast():
fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)

# Construct expected structure
conv = relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float32",
conv = relay.cast(
relay.nn.conv2d(
relay.cast(data, "float16"),
relay.cast(weight, "float16"),
strides=(1, 1),
padding=(1, 1),
out_dtype="float16",
),
"float32",
)
result = relay.nn.softmax(conv)
expected_mod = tvm.IRModule.from_expr(result)
@@ -380,12 +374,12 @@ def test_let_statement_simple():
r2 = var2 + var2
let2 = relay.Let(
var2,
relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
relay.nn.dense(r1, weight, units=20, out_dtype="float16"),
r2,
)
let1 = relay.Let(
var1,
relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
relay.nn.dense(data, weight, units=20, out_dtype="float16"),
let2,
)
expected_mod = tvm.IRModule.from_expr(let1)
@@ -410,7 +404,7 @@ def test_where_simple():
# Create expected module
data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
a = relay.nn.dense(data, weight, units=20, out_dtype="float16")
b = relay.where(data, a, a)
expected_mod = tvm.IRModule.from_expr(b)
expected_mod = InferType()(expected_mod)