diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py
index 6aa3ac09cfee..6f8ecb970221 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -40,7 +40,7 @@
     "nn.conv2d_transpose",
     "nn.conv3d_transpose",
     "nn.dense",
-    # "nn.batch_matmul", # Handled by a special case
+    "nn.batch_matmul",
 ]
 DEFAULT_FOLLOW_LIST = [
     # These ops add new data or change shape
@@ -162,7 +162,9 @@ def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) ->
     # Some discussion here about making this better is here:
     # https://discuss.tvm.apache.org/t/rfc-relay-fp32-fp16-model-support/9994/4?u=andrewzhaoluo
     if hasattr(call_node.attrs, "out_dtype"):
-        return ["float32", mixed_precision_type]
+        # TODO (AndrewZhaoLuo): evaluate consistent support for mixed_type accumulators
+        # return ["float32", mixed_precision_type]
+        return [mixed_precision_type, mixed_precision_type]

     # [accumulation_dtype, output_dtype] for the operations
     return [mixed_precision_type, mixed_precision_type]
@@ -184,12 +186,3 @@ def generic_follow_op(call_node: relay.Call, mixed_precision_type: str) -> List:
 @register_func_to_op_list(list_ops=DEFAULT_NEVER_LIST)
 def generic_never_op(call_node: relay.Call, mixed_precision_type: str) -> List:
     return [MIXED_PRECISION_NEVER] + get_generic_out_dtypes(call_node, mixed_precision_type)
-
-
-@register_mixed_precision_conversion("nn.batch_matmul")
-def nn_batch_matmul(call_node: relay.Call, mixed_precision_type: str) -> List:
-    # TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
-    # Batched matmul has inconsistent support for mixed precision operations.
-    # Many schedules ignore the out_dtype attribute which leads to errors when
-    # input types do not match the out_dtype. Therefore, accumulate to output_dtype.
-    return [MIXED_PRECISION_ALWAYS, "float16", "float16"]
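With the change above, any op that carries an `out_dtype` attribute now accumulates in the mixed precision type instead of `float32`, so the pass stops emitting a trailing cast after "green" ops. A minimal sketch of the new default behavior (shapes illustrative, using only APIs exercised by the tests below):

```python
import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
mod = InferType()(tvm.IRModule.from_expr(conv))

# Before this patch: conv2d(..., out_dtype="float32") wrapped in cast(..., "float16").
# After this patch: conv2d(..., out_dtype="float16") and no trailing cast.
fp16_mod = ToMixedPrecision("float16")(mod)
print(fp16_mod)
```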
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 0eddd965c661..24f0ed6642b5 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -18,14 +18,11 @@
 """
 import numpy as np
 import tvm
-from tvm import te
+import tvm.testing
 import tvm.topi.testing
-from tvm import relay
+from tvm import relay, te, topi
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
-from tvm import topi
-import tvm.topi.testing
-import tvm.testing


 @tvm.testing.uses_gpu
@@ -608,7 +605,7 @@ def _verify(prediction_shape, reduction="mean", ignore_index=-100, dtype="float3
         for kind in ["graph", "debug"]:
             intrp = relay.create_executor(kind, device=dev, target=target)
             out_relay = intrp.evaluate(func)(predictions_np, targets_np, weights_np)
-            tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+            tvm.testing.assert_allclose(out_relay.asnumpy(), out_np, rtol=1e-6, atol=1e-6)

     _verify((10, 5))
     _verify((10, 5, 2, 2))
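On the tolerance change just above: `np.testing.assert_allclose` accepts elementwise when |actual - desired| <= atol + rtol * |desired|, so for unit-scale float32 outputs the allowed error drops from about 1.1e-4 to 2e-6. A quick illustration:

```python
import numpy as np

# Passes: |error| = 5e-7 <= atol + rtol * |desired| = 2e-6
np.testing.assert_allclose(1.0 + 5e-7, 1.0, rtol=1e-6, atol=1e-6)

# Passed under the old rtol=1e-4/atol=1e-5 but raises under the new bound:
# np.testing.assert_allclose(1.0 + 5e-5, 1.0, rtol=1e-6, atol=1e-6)
```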
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index caccd52d60c2..7a3fbfafc089 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -48,6 +48,7 @@ def verify_mixed_precision_output_close(
     result_fp32 = run_module(mod, mod_params)
     fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
     result_fp16 = run_module(fp16_mod, mod_params)

+    # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
@@ -60,7 +61,9 @@ def test_lstm():
     Has internal functions and let statements the pass must work on.
     """
-    units = 3
+    # TODO(AndrewZhaoLuo): investigate why non-even units cause failure in codegen for CUDA
+    # See discussion here: https://github.com/apache/tvm/issues/8294#issuecomment-866190408
+    units = 4
     iterations = 5
     mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
@@ -118,16 +121,13 @@ def test_convert_single_conv():
     fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)

     expected_mod = tvm.IRModule.from_expr(
-        relay.cast(
-            relay.nn.conv2d(
-                relay.cast(data, "float16"),
-                relay.cast(weight, "float16"),
-                strides=(1, 1),
-                padding=(1, 1),
-                out_dtype="float32",
-            ),
-            "float16",
-        )
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float16",
+        ),
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
@@ -156,16 +156,13 @@ def test_convert_single_conv_fp64():
     # Note we still accumulate to FP32 by default, a user would need to overwrite default
     # behavior to make this make more sense.
     expected_mod = tvm.IRModule.from_expr(
-        relay.cast(
-            relay.nn.conv2d(
-                relay.cast(data, "float64"),
-                relay.cast(weight, "float64"),
-                strides=(1, 1),
-                padding=(1, 1),
-                out_dtype="float32",
-            ),
-            "float64",
-        )
+        relay.nn.conv2d(
+            relay.cast(data, "float64"),
+            relay.cast(weight, "float64"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float64",
+        ),
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
@@ -198,15 +195,12 @@ def test_convert_conv_bn():
         "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
         "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.025, rtol=0.01)

     # Creating expected module
     data = relay.cast(relay.var("data", shape=data_shape), "float16")
     weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
-    conv = relay.cast(
-        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
-        "float16",
-    )
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float16")

     bn_shape = [5]
     gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
@@ -254,17 +248,14 @@ def test_green_gray_propagates_simple():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=0.01)

-    conv_expr = relay.cast(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float32",
-        ),
-        "float16",
+    conv_expr = relay.nn.conv2d(
+        relay.cast(data, "float16"),
+        relay.cast(weight, "float16"),
+        strides=(1, 1),
+        padding=(1, 1),
+        out_dtype="float16",
     )
     expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
     expected_mod = tvm.relay.transform.InferType()(expected_mod)
@@ -316,12 +307,15 @@ def test_green_red_not_use_extraneous_cast():
     fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)

     # Construct expected structure
-    conv = relay.nn.conv2d(
-        relay.cast(data, "float16"),
-        relay.cast(weight, "float16"),
-        strides=(1, 1),
-        padding=(1, 1),
-        out_dtype="float32",
+    conv = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float16",
+        ),
+        "float32",
     )
     result = relay.nn.softmax(conv)
     expected_mod = tvm.IRModule.from_expr(result)
@@ -380,12 +374,12 @@ def test_let_statement_simple():
     r2 = var2 + var2
     let2 = relay.Let(
         var2,
-        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        relay.nn.dense(r1, weight, units=20, out_dtype="float16"),
         r2,
     )
     let1 = relay.Let(
         var1,
-        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        relay.nn.dense(data, weight, units=20, out_dtype="float16"),
         let2,
     )
     expected_mod = tvm.IRModule.from_expr(let1)
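The `test_green_red_not_use_extraneous_cast` hunk above captures the one place the expected IR gains a cast under the new default: when a float16 ("green") result feeds a never-convert ("red") op such as `nn.softmax`, the pass inserts a single cast back to `float32` at the boundary instead of accumulating in `float32`. A hedged sketch of that boundary behavior (shapes illustrative):

```python
import tvm
from tvm import relay
from tvm.relay.transform import InferType, ToMixedPrecision

data = relay.var("data", shape=[1, 20], dtype="float32")
weight = relay.var("weight", shape=[20, 20], dtype="float32")
out = relay.nn.softmax(relay.nn.dense(data, weight, units=20))
mod = InferType()(tvm.IRModule.from_expr(out))

# dense runs with out_dtype="float16"; one cast to float32 feeds softmax,
# which stays in float32 because it is on the never-convert list.
fp16_mod = ToMixedPrecision("float16")(mod)
print(fp16_mod)
```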
@@ -410,7 +404,7 @@ def test_where_simple():
     # Create expected module
     data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
     weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
-    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    a = relay.nn.dense(data, weight, units=20, out_dtype="float16")
     b = relay.where(data, a, a)
     expected_mod = tvm.IRModule.from_expr(b)
     expected_mod = InferType()(expected_mod)
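As the retained comment in `test_convert_single_conv_fp64` says, a user who still wants `float32` accumulation must now override the defaults. A hedged sketch of such an override, assuming `register_mixed_precision_conversion` follows TVM's usual `(op_name, func, level)` registration pattern in which a higher level supersedes the generic registration:

```python
from typing import List

from tvm import relay
from tvm.relay.op import register_mixed_precision_conversion
from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS

# level=11 is an assumption: one above the default level so this wins over
# the generic DEFAULT_ALWAYS_LIST registration for nn.dense.
@register_mixed_precision_conversion("nn.dense", level=11)
def dense_fp32_accumulate(call_node: relay.Call, mixed_precision_type: str) -> List:
    # [conversion category, accumulation_dtype, output_dtype], mirroring the
    # removed nn.batch_matmul special case above.
    return [MIXED_PRECISION_ALWAYS, "float32", mixed_precision_type]
```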