From 39d334142ba90fbf0067fa09458e78f9dafcbc95 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 23 Nov 2023 03:40:09 +0000 Subject: [PATCH 01/27] add type promotion table. --- paddle/fluid/eager/type_promotion_utils.h | 105 ++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 paddle/fluid/eager/type_promotion_utils.h diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h new file mode 100644 index 00000000000000..99a2875a6a54a4 --- /dev/null +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -0,0 +1,105 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/eager/api/utils/global_utils.h" + +namespace egr { + +inline int DataTypeToNum(const phi::DataType& dtype) { + switch (dtype) { + case phi::DataType::UINT8: + return 0; + case phi::DataType::INT8: + return 1; + case phi::DataType::INT16: + return 2; + case phi::DataType::INT32: + return 3; + case phi::DataType::INT64: + return 4; + case phi::DataType::FLOAT16: + return 5; + case phi::DataType::FLOAT32: + return 6; + case phi::DataType::FLOAT64: + return 7; + case phi::DataType::COMPLEX64: + return 8; + case phi::DataType::COMPLEX128: + return 9; + case phi::DataType::BOOL: + return 10; + case phi::DataType::BFLOAT16: + return 11; + default: + PD_THROW("Invalid enum data type for type promote `", dtype, "`."); + } +} + +static inline bool is_support_float(phi::DataType dtype) { + if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::FLOAT32 || + dtype == phi::DataType::FLOAT64 || dtype == phi::DataType::BFLOAT16) { + return true; + } else { + return false; + } +} + +static inline bool is_support_int(phi::DataType dtype) { + if (dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64) { + return true; + } else { + return false; + } +} + +inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) { + constexpr auto u1 = phi::DataType::UINT8; + constexpr auto i1 = phi::DataType::INT8; + constexpr auto i2 = phi::DataType::INT16; + constexpr auto i4 = phi::DataType::INT32; + constexpr auto i8 = phi::DataType::INT64; + constexpr auto f2 = phi::DataType::FLOAT16; + constexpr auto f4 = phi::DataType::FLOAT32; + constexpr auto f8 = phi::DataType::FLOAT64; + constexpr auto c4 = phi::DataType::COMPLEX64; + constexpr auto c8 = phi::DataType::COMPLEX128; + constexpr auto b1 = phi::DataType::BOOL; + constexpr auto bf = phi::DataType::BFLOAT16; + + // this matrix has to be consistent with + // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we + // are not sure about the correct value for type promotion. 
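  // Example lookups from the table below: (FLOAT16, INT32) -> FLOAT16,
  // (UINT8, INT8) -> INT16, (BFLOAT16, FLOAT16) -> FLOAT32. Floating dtypes
  // take precedence over integral dtypes, and mixing the two 16-bit float
  // formats widens the result to float32.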
+ // clang-format off + static constexpr phi::DataType _promoteTypesLookup[ + 12][12] = { + /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, + }; + // clang-format on + return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; +} + +} // namespace egr From 76169122667e6cb4e1d3abe4af3b45304f9eaab9 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 23 Nov 2023 06:04:17 +0000 Subject: [PATCH 02/27] fix codestyle. --- paddle/fluid/eager/type_promotion_utils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 99a2875a6a54a4..1302f6de150cb3 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#pragma once #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/utils/global_utils.h" From f7992a7c635cbd36333cbcfe87ad027b1ec1231a Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 23 Nov 2023 07:57:51 +0000 Subject: [PATCH 03/27] add python table. --- python/paddle/base/__init__.py | 1 + python/paddle/base/type_promotion.py | 73 ++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 python/paddle/base/type_promotion.py diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 7e5ac9c1d92c44..19342a4ca1f8f6 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -126,6 +126,7 @@ HeterXpuTrainer, ) from .backward import append_backward +from . import type_promotion Tensor = LoDTensor enable_imperative = enable_dygraph diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py new file mode 100644 index 00000000000000..7125a6683bbbb2 --- /dev/null +++ b/python/paddle/base/type_promotion.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
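# Python-side dtype promotion lookup table. The index order in `Number`
# (uint8 ... bfloat16 -> 0 ... 11) mirrors DataTypeToNum in
# paddle/fluid/eager/type_promotion_utils.h, and the rows/columns of
# promoteTypesLookup follow the same order, so the Python and C++ tables
# stay consistent.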
+ +import paddle + +u1 = paddle.uint8 +i1 = paddle.int8 +i2 = paddle.int16 +i4 = paddle.int32 +i8 = paddle.int64 +f2 = paddle.float16 +f4 = paddle.float32 +f8 = paddle.float64 +c4 = paddle.complex64 +c8 = paddle.complex128 +b1 = paddle.bool +bf = paddle.bfloat16 + + +Number = { + paddle.uint8: 0, + paddle.int8: 1, + paddle.int16: 2, + paddle.int32: 3, + paddle.int64: 4, + paddle.float16: 5, + paddle.float32: 6, + paddle.float64: 7, + paddle.complex64: 8, + paddle.complex128: 9, + paddle.bool: 10, + paddle.bfloat16: 11, +} + +promoteTypesLookup = [ + [u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf], + [i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf], + [i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf], + [i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf], + [i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf], + [f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4], + [f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4], + [f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8], + [c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4], + [c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8], + [u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf], + [bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf], +] + + +def get_result_dtype(x_dtype, y_dtype): + if x_dtype == y_dtype: + return x_dtype + else: + try: + return promoteTypesLookup[Number[x_dtype]][Number[y_dtype]] + except: + print( + "got unsupport dtype for type promotion: {} and {}.".format( + x_dtype, y_dtype + ) + ) From 70993d6831c0f7cc9f843288029932d5244149a9 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Fri, 24 Nov 2023 02:56:47 +0000 Subject: [PATCH 04/27] fix dtype. --- python/paddle/base/type_promotion.py | 51 ++++++++++++++-------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py index 7125a6683bbbb2..5adcb127e71bb0 100644 --- a/python/paddle/base/type_promotion.py +++ b/python/paddle/base/type_promotion.py @@ -11,36 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from paddle.framework import dtype -import paddle - -u1 = paddle.uint8 -i1 = paddle.int8 -i2 = paddle.int16 -i4 = paddle.int32 -i8 = paddle.int64 -f2 = paddle.float16 -f4 = paddle.float32 -f8 = paddle.float64 -c4 = paddle.complex64 -c8 = paddle.complex128 -b1 = paddle.bool -bf = paddle.bfloat16 +u1 = dtype.uint8 +i1 = dtype.int8 +i2 = dtype.int16 +i4 = dtype.int32 +i8 = dtype.int64 +f2 = dtype.float16 +f4 = dtype.float32 +f8 = dtype.float64 +c4 = dtype.complex64 +c8 = dtype.complex128 +b1 = dtype.bool +bf = dtype.bfloat16 Number = { - paddle.uint8: 0, - paddle.int8: 1, - paddle.int16: 2, - paddle.int32: 3, - paddle.int64: 4, - paddle.float16: 5, - paddle.float32: 6, - paddle.float64: 7, - paddle.complex64: 8, - paddle.complex128: 9, - paddle.bool: 10, - paddle.bfloat16: 11, + dtype.uint8: 0, + dtype.int8: 1, + dtype.int16: 2, + dtype.int32: 3, + dtype.int64: 4, + dtype.float16: 5, + dtype.float32: 6, + dtype.float64: 7, + dtype.complex64: 8, + dtype.complex128: 9, + dtype.bool: 10, + dtype.bfloat16: 11, } promoteTypesLookup = [ From 164daec4bdd40250b11652bc57e69be938a38538 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Fri, 24 Nov 2023 03:00:23 +0000 Subject: [PATCH 05/27] remove useless note --- paddle/fluid/eager/type_promotion_utils.h | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 1302f6de150cb3..00313512d2c465 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -79,12 +79,7 @@ inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) { constexpr auto b1 = phi::DataType::BOOL; constexpr auto bf = phi::DataType::BFLOAT16; - // this matrix has to be consistent with - // AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS undefined is used where we - // are not sure about the correct value for type promotion. - // clang-format off - static constexpr phi::DataType _promoteTypesLookup[ - 12][12] = { + static constexpr phi::DataType _promoteTypesLookup[12][12] = { /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, @@ -99,7 +94,7 @@ inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) { /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, }; - // clang-format on + return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; } From 863f139f7f2dfd9f88376de6a38acfa68d83babc Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Mon, 27 Nov 2023 03:15:37 +0000 Subject: [PATCH 06/27] fix static-check --- python/paddle/base/type_promotion.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py index 5adcb127e71bb0..fd36f554028f88 100644 --- a/python/paddle/base/type_promotion.py +++ b/python/paddle/base/type_promotion.py @@ -65,8 +65,6 @@ def get_result_dtype(x_dtype, y_dtype): try: return promoteTypesLookup[Number[x_dtype]][Number[y_dtype]] except: - print( - "got unsupport dtype for type promotion: {} and {}.".format( - x_dtype, y_dtype - ) + raise TypeError( + f"got unsupport dtype for type promotion: {x_dtype} and {y_dtype}." 
) From 3cafffa8e2598a500c1200ffbdb1be678953c618 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 29 Nov 2023 06:36:51 +0000 Subject: [PATCH 07/27] add eager T+T logic. --- .../forwards/multiply_fwd_func.cc | 31 ++++ .../generator/eager_gen.py | 97 ++++++++++++- paddle/fluid/eager/type_promotion_utils.h | 134 ++++++++++++++++++ paddle/fluid/imperative/type_promotion.cc | 42 ++++++ paddle/fluid/pybind/eager_math_op_patch.cc | 133 +---------------- .../test_math_op_patch_var_base.py | 82 ++++++++++- 6 files changed, 389 insertions(+), 130 deletions(-) create mode 100644 paddle/fluid/eager/type_promotion_utils.h create mode 100644 paddle/fluid/imperative/type_promotion.cc diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc index 092620120cae19..5542c5e51eac1a 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/eager/eager_amp_auto_cast.h" #include "paddle/fluid/eager/eager_layout_auto_tune.h" #include "paddle/fluid/eager/nan_inf_utils.h" +#include "paddle/fluid/eager/type_promotion_utils.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/api/include/sparse_api.h" @@ -56,6 +57,21 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } } + // Type promotion Logic + paddle::small_vector, egr::kSlotSmallVectorSize> + promote_tensors_vector = {{x}, {y}}; + if (egr::NeedTypePromotion(promote_tensors_vector)) { + VLOG(5) << "got different data type, run type protmotion automatically."; + auto op_name = phi::TransToFluidOpName("add"); + + auto promotion_type = egr::GetPromoteDtype(op_name, promote_tensors_vector); + + auto new_x = egr::PromoteCast("x", x, promotion_type); + auto new_y = egr::PromoteCast("y", y, promotion_type); + + return multiply_ad_func(new_x, new_y); + } + // Layout autotune if (egr::Controller::Instance().UseLayoutAutoTune()) { @@ -388,6 +404,21 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } } + // Type promotion Logic + paddle::small_vector, egr::kSlotSmallVectorSize> + promote_tensors_vector = {{x}, {y}}; + if (egr::NeedTypePromotion(promote_tensors_vector)) { + VLOG(5) << "got different data type, run type protmotion automatically."; + auto op_name = phi::TransToFluidOpName("add"); + + auto promotion_type = egr::GetPromoteDtype(op_name, promote_tensors_vector); + + auto new_x = egr::PromoteCast("x", x, promotion_type); + auto new_y = egr::PromoteCast("y", y, promotion_type); + + return multiply_ad_func(new_x, new_y); + } + // Layout autotune if (egr::Controller::Instance().UseLayoutAutoTune()) { diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index ff1758e3ef93a4..40ac64b7befe08 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -75,6 +75,13 @@ "tanh_triple_grad", ] +# white ops list whose kernel can automaically do type promotion. 
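# For each op in the list below, the code generator emits a type promotion
# block (NeedTypePromotion / GetPromoteDtype / PromoteCast) before the normal
# forward dispatch, mirroring the hand-written multiply_ad_func change above.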
+type_promote_white_list = [ + "add", + "subtract", + "greater_than", +] + # dict of special api that forward api's output will affect bacward api's output # bacward api's output usually affected by backward api's input special_prune_dict = { @@ -247,6 +254,8 @@ class {} : public egr::GradNodeBase {{ // Dygraph Record Event {} // AMP Logic +{} + // Type promotion Logic {} // Layout autotune {} @@ -315,6 +324,8 @@ class {} : public egr::GradNodeBase {{ // Dygraph Record Event {} // AMP Logic +{} + // Type promotion Logic {} // Layout autotune {} @@ -447,7 +458,7 @@ class {} : public egr::GradNodeBase {{ #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/api/lib/data_transform.h" - +#include "paddle/fluid/eager/type_promotion_utils.h" PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_string(tensor_operants_mode); {} @@ -512,6 +523,33 @@ class {} : public egr::GradNodeBase {{ }} }} """ + +# PROMOTION_LOGIC_TEMPLATE = """ if (egr::Controller::Instance().UseTypePromotion()) {{ +# VLOG(5) << "got different data type, run type protmotion automatically."; +# {} +# paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; +# {} +# {} +# {} +# {{ +# paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentTracer(), false); +# {} +# }} +# }} +# """ + +PROMOTION_LOGIC_TEMPLATE = """ paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; + if (egr::NeedTypePromotion(promote_tensors_vector)) {{ + VLOG(5) << "got different data type, run type protmotion automatically."; + {} + {} + {} + {} + {} + }} +""" + + LAYOUT_LOGIC_TEMPLATE = """ if (egr::Controller::Instance().UseLayoutAutoTune()) {{ paddle::small_vector, egr::kSlotSmallVectorSize> tensors_vector = {}; @@ -1463,6 +1501,10 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_tensors_vector_optional_list = [] amp_autocast_list = [] amp_autocast_optional_list = [] + type_promote_vector_list = [] + type_promote_vector_optional_list = [] + type_promote_list = [] + type_promote_optional_list = [] layout_autotune_list = [] layout_autotune_optional_list = [] layout_tensors_vector_optional_list = [] @@ -1489,6 +1531,12 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_optional_list.append( f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) + type_promote_vector_optional_list.append( + f"if ({name}) promote_tensors_vector.push_back({{ *{name} }});\n" + ) + type_promote_optional_list.append( + f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" + ) layout_tensors_vector_optional_list.append( f"if ({name}) tensors_vector.push_back({{ *{name} }});\n" ) @@ -1512,6 +1560,10 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_list.append( f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) + type_promote_vector_list.append(f"{name}") + type_promote_list.append( + f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" + ) layout_autotune_list.append( f"auto new_{name} = transformer->TransInTensor(\"{name}\", {name});\n" ) @@ -1533,6 +1585,12 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_optional_list.append( f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) + type_promote_vector_optional_list.append( + f"if ({name}) promote_tensors_vector.push_back( 
*{name} );\n" + ) + type_promote_optional_list.append( + f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" + ) layout_autotune_optional_list.append( f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n" ) @@ -1549,6 +1607,10 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_list.append( f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) + type_promote_vector_list.append(f"{name}") + type_promote_list.append( + f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" + ) layout_autotune_list.append( f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n" ) @@ -1804,7 +1866,31 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_list_str, amp_call_str, ) - + # Forward type promotion logic + if forward_api_name in type_promote_white_list: + type_promote_get_dst_dtype_str = "auto promotion_type = egr::GetPromoteDtype(op_name, promote_tensors_vector);\n" + type_promote_vector_optional_list_str = " ".join( + type_promote_vector_optional_list + ) + type_promote_list_str = ( + " ".join(type_promote_list) + + " " + + " ".join(type_promote_optional_list) + ) + type_promotion_logic_str = PROMOTION_LOGIC_TEMPLATE.format( + amp_tensors_vector_list_str, + kernel_trans2_op_name_str, + type_promote_vector_optional_list_str, + type_promote_get_dst_dtype_str, + type_promote_list_str, + amp_call_str, + ) + else: + type_promotion_logic_str = ( + "\n VLOG(5) << \" No Promotion for {} api. \"; ".format( + forward_ad_function_name + ) + ) # Forward layout autotune layout_autotune_list_str = " ".join( layout_autotune_list @@ -1841,6 +1927,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_logic_str = "\n VLOG(7) << \" No AMP for {} because it has no input. \"; ".format( forward_ad_function_name ) + type_promotion_logic_str = ( + "\n VLOG(7) << \" No Promotion for {} api. \"; ".format( + forward_ad_function_name + ) + ) self.forward_definition_str += ( FORWARD_ONLY_FUNCTION_TEMPLATE.format( returns_type_str, @@ -1849,6 +1940,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): forward_api_name, dygraph_event_str, amp_logic_str, + type_promotion_logic_str, layout_logic_str, forward_api_name, before_log_str, @@ -1871,6 +1963,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): forward_api_name, dygraph_event_str, amp_logic_str, + type_promotion_logic_str, layout_logic_str, inputs_autograd_meta_str, forward_api_name, diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h new file mode 100644 index 00000000000000..58ab6bf4d0afb3 --- /dev/null +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -0,0 +1,134 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
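// Type promotion helpers used by the generated dygraph forward functions:
// NeedTypePromotion reports whether two input dtypes differ and are eligible
// for promotion, GetPromoteDtype resolves the promoted dtype from the lookup
// table, and PromoteCast casts an input tensor when its dtype differs from
// the promoted one.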
+#pragma once + +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/imperative/type_promotion.h" + +namespace egr { + +inline int DataTypeToNum(const phi::DataType& dtype) { + switch (dtype) { + case phi::DataType::UINT8: + return 0; + case phi::DataType::INT8: + return 1; + case phi::DataType::INT16: + return 2; + case phi::DataType::INT32: + return 3; + case phi::DataType::INT64: + return 4; + case phi::DataType::FLOAT16: + return 5; + case phi::DataType::FLOAT32: + return 6; + case phi::DataType::FLOAT64: + return 7; + case phi::DataType::COMPLEX64: + return 8; + case phi::DataType::COMPLEX128: + return 9; + case phi::DataType::BOOL: + return 10; + case phi::DataType::BFLOAT16: + return 11; + default: + PD_THROW("Invalid enum data type for type promote `", dtype, "`."); + } +} + +inline paddle::Tensor PromoteCast(const std::string& input_name, + const paddle::Tensor& input, + const phi::DataType& dst_dtype, + bool trace_backward = true) { + if (input.dtype() != dst_dtype) { + return Cast(input, dst_dtype, trace_backward); + } else { + return input; + } +} + +static inline bool is_support_float(phi::DataType dtype) { + if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::FLOAT32 || + dtype == phi::DataType::FLOAT64 || dtype == phi::DataType::BFLOAT16) { + return true; + } else { + return false; + } +} + +static inline bool is_support_int(phi::DataType dtype) { + if (dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64) { + return true; + } else { + return false; + } +} + +inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) { + constexpr auto u1 = phi::DataType::UINT8; + constexpr auto i1 = phi::DataType::INT8; + constexpr auto i2 = phi::DataType::INT16; + constexpr auto i4 = phi::DataType::INT32; + constexpr auto i8 = phi::DataType::INT64; + constexpr auto f2 = phi::DataType::FLOAT16; + constexpr auto f4 = phi::DataType::FLOAT32; + constexpr auto f8 = phi::DataType::FLOAT64; + constexpr auto c4 = phi::DataType::COMPLEX64; + constexpr auto c8 = phi::DataType::COMPLEX128; + constexpr auto b1 = phi::DataType::BOOL; + constexpr auto bf = phi::DataType::BFLOAT16; + + static constexpr phi::DataType _promoteTypesLookup[12][12] = { + /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, + }; + return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; +} + +inline phi::DataType GetPromoteDtype( + const std::string& op_name, + const paddle::small_vector, + kSlotSmallVectorSize>& promote_tensors_vector) { + return promoteTypes(promote_tensors_vector[0][0].dtype(), + promote_tensors_vector[1][0].dtype()); +} + +inline bool NeedTypePromotion( + const paddle::small_vector, + kSlotSmallVectorSize>& promote_tensors_vector) { + // T+T only support type promotion between float, int32, int64 + if 
((promote_tensors_vector[0][0].dtype() != + promote_tensors_vector[1][0].dtype()) && + (is_support_float(a) || is_support_int(a)) && + (is_support_float(b) || is_support_int(b))) { + return true; + } else { + return false; + } +} + +} // namespace egr diff --git a/paddle/fluid/imperative/type_promotion.cc b/paddle/fluid/imperative/type_promotion.cc new file mode 100644 index 00000000000000..a1d1e07b47b797 --- /dev/null +++ b/paddle/fluid/imperative/type_promotion.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/imperative/type_promotion.h" +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/imperative/type_defs.h" +#include "paddle/fluid/imperative/var_helper.h" + +namespace paddle { +namespace imperative { +TypePrmotionGuard::TypePrmotionGuard(std::shared_ptr tracer, + bool use_type_promotion_) + : tracer_(tracer) { + pre_type_promotion = tracer_->UseTypePromotion(); + if (pre_type_promotion != use_type_promotion_) { + tracer_->EnableTypePromotion(); + if (!use_type_promotion_) { + tracer_->DisableTypePromotion(); + } + } +} + +TypePrmotionGuard::~TypePrmotionGuard() { + if (pre_type_promotion) { + tracer_->EnableTypePromotion(); + } else { + tracer_->DisableTypePromotion(); + } +} +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index aa7a27db207364..524d08eae2fe20 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -252,36 +252,7 @@ static PyObject* tensor__add__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types or unify right var type to left var - phi::DataType lhs_dtype = self_tensor.dtype(); - phi::DataType rhs_dtype = other_tensor.dtype(); - if (lhs_dtype != rhs_dtype) { - // note: only op_type in _supported_promote_complex_types_ should promote - // dtype - if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || - _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { - phi::DataType promote_dtype = - framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( - framework::TransToProtoVarType(lhs_dtype), - framework::TransToProtoVarType(rhs_dtype))); - if (lhs_dtype != promote_dtype) { - // cast - eager_gil_scoped_release guard; - self_tensor = cast_ad_func(self_tensor, promote_dtype); - } - if (rhs_dtype != promote_dtype) { - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, promote_dtype); - } - } else { - VLOG(6) << "The dtype of left and right Tensor are not the same, left " - "dtype is " - << lhs_dtype << ", but right dtype is " << rhs_dtype - << ", the right dtype will convert to " << lhs_dtype; - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, lhs_dtype); - } - } + // 3. 
promote types move to add_ad_func // 4. calculation VLOG(6) << "Calling add_ad_func in tensor__add__method"; @@ -358,34 +329,8 @@ static PyObject* tensor__sub__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types or unify right var type to left var - phi::DataType lhs_dtype = self_tensor.dtype(); - phi::DataType rhs_dtype = other_tensor.dtype(); - if (lhs_dtype != rhs_dtype) { - if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || - _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { - phi::DataType promote_dtype = - framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( - framework::TransToProtoVarType(lhs_dtype), - framework::TransToProtoVarType(rhs_dtype))); - if (lhs_dtype != promote_dtype) { - // cast - eager_gil_scoped_release guard; - self_tensor = cast_ad_func(self_tensor, promote_dtype); - } - if (rhs_dtype != promote_dtype) { - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, promote_dtype); - } - } else { - VLOG(6) << "The dtype of left and right Tensor are not the same, left " - "dtype is " - << lhs_dtype << ", but right dtype is " << rhs_dtype - << ", the right dtype will convert to " << lhs_dtype; - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, lhs_dtype); - } - } + // 3. promote types move to subtract_ad_func + // 4. calculation VLOG(6) << "Calling subtract_ad_func in tensor__sub__method"; { @@ -460,34 +405,7 @@ static PyObject* tensor__rsub__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types or unify right var type to left var - phi::DataType lhs_dtype = self_tensor.dtype(); - phi::DataType rhs_dtype = other_tensor.dtype(); - if (lhs_dtype != rhs_dtype) { - if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || - _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { - phi::DataType promote_dtype = - framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( - framework::TransToProtoVarType(lhs_dtype), - framework::TransToProtoVarType(rhs_dtype))); - if (lhs_dtype != promote_dtype) { - // cast - eager_gil_scoped_release guard; - self_tensor = cast_ad_func(self_tensor, promote_dtype); - } - if (rhs_dtype != promote_dtype) { - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, promote_dtype); - } - } else { - VLOG(6) << "The dtype of left and right Tensor are not the same, left " - "dtype is " - << lhs_dtype << ", but right dtype is " << rhs_dtype - << ", the right dtype will convert to " << lhs_dtype; - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, lhs_dtype); - } - } + // 3. promote types move to subtract_ad_func // 4. calculation VLOG(6) << "Calling subtract_ad_func in tensor__rsub__method"; @@ -568,36 +486,7 @@ static PyObject* tensor__mul__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. 
promote types or unify right var type to left var - phi::DataType lhs_dtype = self_tensor.dtype(); - phi::DataType rhs_dtype = other_tensor.dtype(); - if (lhs_dtype != rhs_dtype) { - // note: only op_type in _supported_promote_complex_types_ should promote - // dtype - if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || - _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { - phi::DataType promote_dtype = - framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( - framework::TransToProtoVarType(lhs_dtype), - framework::TransToProtoVarType(rhs_dtype))); - if (lhs_dtype != promote_dtype) { - // cast - eager_gil_scoped_release guard; - self_tensor = cast_ad_func(self_tensor, promote_dtype); - } - if (rhs_dtype != promote_dtype) { - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, promote_dtype); - } - } else { - VLOG(6) << "The dtype of left and right Tensor are not the same, left " - "dtype is " - << lhs_dtype << ", but right dtype is " << rhs_dtype - << ", the right dtype will convert to " << lhs_dtype; - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, lhs_dtype); - } - } + // 3. promote types move to multiply_ad_func // 4. calculation VLOG(6) << "Calling multiply_ad_func in tensor__mul__method"; @@ -927,17 +816,7 @@ static PyObject* tensor__gt__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types or unify right var type to left var - phi::DataType lhs_dtype = self_tensor.dtype(); - phi::DataType rhs_dtype = other_tensor.dtype(); - if (lhs_dtype != rhs_dtype) { - VLOG(6) << "The dtype of left and right Tensor are not the same, left " - "dtype is " - << lhs_dtype << ", but right dtype is " << rhs_dtype - << ", the right dtype will convert to " << lhs_dtype; - eager_gil_scoped_release guard; - other_tensor = cast_ad_func(other_tensor, lhs_dtype); - } + // 3. promote types move to greater_than_ad_func // 4. 
calculation VLOG(6) << "Calling greater_than_ad_func in tensor__gt__method"; diff --git a/test/legacy_test/test_math_op_patch_var_base.py b/test/legacy_test/test_math_op_patch_var_base.py index af5fbd9ba9ca1c..f7a63a0ee3d91f 100644 --- a/test/legacy_test/test_math_op_patch_var_base.py +++ b/test/legacy_test/test_math_op_patch_var_base.py @@ -35,6 +35,28 @@ def test_add(self): res = a + b np.testing.assert_array_equal(res.numpy(), a_np + b_np) + def test_type_promotion_add_F_F(self): + a_np = np.random.random(self.shape).astype(np.float32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a + b + res_t = b + a + np.testing.assert_array_equal(res_t.numpy(), res.numpy()) + np.testing.assert_array_equal(res.numpy(), a_np + b_np) + + def test_type_promotion_add_F_I(self): + a_np = np.random.random(self.shape).astype(np.int32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a + b + res_t = b + a + np.testing.assert_array_equal(res_t.numpy(), res.numpy()) + np.testing.assert_array_equal(res.numpy(), a_np + b_np) + def test_sub(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) @@ -43,7 +65,25 @@ def test_sub(self): b = base.dygraph.to_variable(b_np) res = a - b np.testing.assert_array_equal(res.numpy(), a_np - b_np) - + + def test_type_promotion_sub_F_F(self): + a_np = np.random.random(self.shape).astype(np.float32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a - b + np.testing.assert_array_equal(res.numpy(), a_np - b_np) + + def test_type_promotion_sub_F_I(self): + a_np = np.random.random(self.shape).astype(np.int32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a - b + np.testing.assert_array_equal(res.numpy(), a_np - b_np) + def test_mul(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) @@ -53,6 +93,28 @@ def test_mul(self): res = a * b np.testing.assert_array_equal(res.numpy(), a_np * b_np) + def test_type_promotion_mul_F_F(self): + a_np = np.random.random(self.shape).astype(np.float32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a * b + res_t = b * a + np.testing.assert_array_equal(res_t.numpy(), res.numpy()) + np.testing.assert_array_equal(res.numpy(), a_np * b_np) + + def test_type_promotion_mul_F_I(self): + a_np = np.random.random(self.shape).astype(np.int32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a * b + res_t = b * a + np.testing.assert_array_equal(res_t.numpy(), res.numpy()) + np.testing.assert_array_equal(res.numpy(), a_np * b_np) + def test_div(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) @@ -219,6 +281,24 @@ def test_greater_than(self): res = a > b np.testing.assert_array_equal(res.numpy(), a_np > b_np) + def test_type_promotion_greater_than_F_F(self): + a_np 
= np.random.random(self.shape).astype(np.float32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a > b + np.testing.assert_array_equal(res.numpy(), a_np > b_np) + + def test_type_promotion_greater_than_F_I(self): + a_np = np.random.random(self.shape).astype(np.int32) + b_np = np.random.random(self.shape).astype(np.float16) + with base.dygraph.guard(): + a = base.dygraph.to_variable(a_np) + b = base.dygraph.to_variable(b_np) + res = a > b + np.testing.assert_array_equal(res.numpy(), a_np > b_np) + def test_greater_equal(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) From a1c649a92afe0adb7184df48f4d7c4e968eb6a8f Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 29 Nov 2023 06:37:34 +0000 Subject: [PATCH 08/27] remove useless file. --- paddle/fluid/imperative/type_promotion.cc | 42 ----------------------- 1 file changed, 42 deletions(-) delete mode 100644 paddle/fluid/imperative/type_promotion.cc diff --git a/paddle/fluid/imperative/type_promotion.cc b/paddle/fluid/imperative/type_promotion.cc deleted file mode 100644 index a1d1e07b47b797..00000000000000 --- a/paddle/fluid/imperative/type_promotion.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "paddle/fluid/imperative/type_promotion.h" -#include "paddle/fluid/eager/eager_tensor.h" -#include "paddle/fluid/imperative/tracer.h" -#include "paddle/fluid/imperative/type_defs.h" -#include "paddle/fluid/imperative/var_helper.h" - -namespace paddle { -namespace imperative { -TypePrmotionGuard::TypePrmotionGuard(std::shared_ptr tracer, - bool use_type_promotion_) - : tracer_(tracer) { - pre_type_promotion = tracer_->UseTypePromotion(); - if (pre_type_promotion != use_type_promotion_) { - tracer_->EnableTypePromotion(); - if (!use_type_promotion_) { - tracer_->DisableTypePromotion(); - } - } -} - -TypePrmotionGuard::~TypePrmotionGuard() { - if (pre_type_promotion) { - tracer_->EnableTypePromotion(); - } else { - tracer_->DisableTypePromotion(); - } -} -} // namespace imperative -} // namespace paddle From 4ce90342e3987c746957a37ce7f2c591d9f14654 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 29 Nov 2023 06:39:25 +0000 Subject: [PATCH 09/27] remove useless line. 
--- .../auto_code_generator/generator/eager_gen.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 40ac64b7befe08..ccda5c8f34a746 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -524,20 +524,6 @@ class {} : public egr::GradNodeBase {{ }} """ -# PROMOTION_LOGIC_TEMPLATE = """ if (egr::Controller::Instance().UseTypePromotion()) {{ -# VLOG(5) << "got different data type, run type protmotion automatically."; -# {} -# paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; -# {} -# {} -# {} -# {{ -# paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentTracer(), false); -# {} -# }} -# }} -# """ - PROMOTION_LOGIC_TEMPLATE = """ paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; if (egr::NeedTypePromotion(promote_tensors_vector)) {{ VLOG(5) << "got different data type, run type protmotion automatically."; From 359a689e17355f2d9741d15bc0a12e2986f94e35 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 29 Nov 2023 06:53:22 +0000 Subject: [PATCH 10/27] fix --- paddle/fluid/eager/type_promotion_utils.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 58ab6bf4d0afb3..0c2c75b9d991b0 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -120,10 +120,10 @@ inline phi::DataType GetPromoteDtype( inline bool NeedTypePromotion( const paddle::small_vector, kSlotSmallVectorSize>& promote_tensors_vector) { - // T+T only support type promotion between float, int32, int64 - if ((promote_tensors_vector[0][0].dtype() != - promote_tensors_vector[1][0].dtype()) && - (is_support_float(a) || is_support_int(a)) && + // only support type promotion between float, int32, int64 + phi::DataType a = promote_tensors_vector[0][0].dtype(); + phi::DataType b = promote_tensors_vector[1][0].dtype(); + if ((a != b) && (is_support_float(a) || is_support_int(a)) && (is_support_float(b) || is_support_int(b))) { return true; } else { From 5af5d8cc9f5f3a4b79c5a477a25ede802e8a82af Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 29 Nov 2023 07:50:07 +0000 Subject: [PATCH 11/27] dtype promotion for operator overload in static mode --- python/paddle/base/layers/math_op_patch.py | 12 +- .../legacy_test/test_tensor_type_promotion.py | 226 ++++++++++++++++++ 2 files changed, 236 insertions(+), 2 deletions(-) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 1f070882758b92..956c3435f4fca4 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -519,10 +519,18 @@ def __impl__(self, other_var): current_block(self), value=other_var, dtype=lhs_dtype ) - # 3. unify right var type to left var + # 3. 
type promotion rhs_dtype = safe_get_dtype(other_var) + if lhs_dtype != rhs_dtype: - other_var = astype(other_var, lhs_dtype) + from ..type_promotion import get_result_dtype + + common_dtype = get_result_dtype(lhs_dtype, rhs_dtype) + if rhs_dtype != common_dtype: + other_var = astype(other_var, common_dtype) + if lhs_dtype != common_dtype: + self = astype(self, common_dtype) + if reverse: tmp = self self = other_var diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index a4e3f76d7ee8be..b322ad8d485577 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -51,5 +51,231 @@ def test_operator(self): self.div_operator() +def create_test_case(baseclass, ldtype, rdtype, expected_out_dtype=None): + class TestPromotion(baseclass): + def set_dtype(self): + self.ldtype = ldtype + self.rdtype = rdtype + self.expected_out_dtype = expected_out_dtype + + cls_name = f"{baseclass.__name__}Between{ldtype}And{rdtype}" + TestPromotion.__name__ = cls_name + globals()[cls_name] = TestPromotion + + +class TestOperatorOverloadAddInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.set_dtype() + self.exe = paddle.static.Executor() + + def set_dtype(self): + self.ldtype = 'float32' + self.rdtype = 'float64' + self.expected_out_dtype = 'float64' + + def generate_test_value(self): + self.l_value = (paddle.randn((4, 3, 2)) * 10).astype(self.ldtype) + self.r_value = (paddle.randn((4, 3, 2)) * 10).astype(self.rdtype) + + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value + self.r_value + out_reverse = self.r_value + self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + def test_dtype_is_expected(self): + res = self.run_api() + self.assertEqual(res[0].dtype.__str__(), self.expected_out_dtype) + self.assertEqual(res[1].dtype.__str__(), self.expected_out_dtype) + + +create_test_case( + TestOperatorOverloadAddInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float16', 'float64', 'float64' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float16', 'complex64', 'complex64' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float16', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadAddInStatic, 'float32', 'float64', 'float64' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float32', 'complex64', 'complex64' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float32', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadAddInStatic, 'float64', 'complex64', 'complex128' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float64', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadAddInStatic, 'complex64', 'complex128', 'complex128' +) + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'complex64', 'complex64' + ) + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'complex128', 'complex128' + ) + + +class 
TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic): + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value - self.r_value + out_reverse = self.r_value - self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case( + TestOperatorOverloadSubInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float16', 'float64', 'float64' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float16', 'complex64', 'complex64' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float16', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadSubInStatic, 'float32', 'float64', 'float64' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float32', 'complex64', 'complex64' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float32', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadSubInStatic, 'float64', 'complex64', 'complex128' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float64', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadSubInStatic, 'complex64', 'complex128', 'complex128' +) + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'complex64', 'complex64' + ) + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'complex128', 'complex128' + ) + + +class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic): + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value * self.r_value + out_reverse = self.r_value * self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'complex64', 'complex64' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float32', 'complex64', 'complex64' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float32', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadMulInStatic, 'float64', 'complex64', 'complex128' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float64', 'complex128', 'complex128' +) + +create_test_case( + TestOperatorOverloadMulInStatic, 'complex64', 'complex128', 'complex128' +) + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 
'complex64', 'complex64' + ) + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'complex128', 'complex128' + ) + + if __name__ == '__main__': unittest.main() From 1b069a2efdea42be7b7ace3d0e0972886bbe1ceb Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Wed, 29 Nov 2023 10:10:35 +0000 Subject: [PATCH 12/27] only support float series --- python/paddle/base/layers/math_op_patch.py | 15 ++--- python/paddle/base/type_promotion.py | 12 ++++ .../legacy_test/test_tensor_type_promotion.py | 58 +++++++++++++++++++ 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 956c3435f4fca4..589d29799f90d9 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -523,13 +523,14 @@ def __impl__(self, other_var): rhs_dtype = safe_get_dtype(other_var) if lhs_dtype != rhs_dtype: - from ..type_promotion import get_result_dtype - - common_dtype = get_result_dtype(lhs_dtype, rhs_dtype) - if rhs_dtype != common_dtype: - other_var = astype(other_var, common_dtype) - if lhs_dtype != common_dtype: - self = astype(self, common_dtype) + from ..type_promotion import get_result_dtype, is_support_float + + if is_support_float(lhs_dtype) and is_support_float(rhs_dtype): + common_dtype = get_result_dtype(lhs_dtype, rhs_dtype) + if rhs_dtype != common_dtype: + other_var = astype(other_var, common_dtype) + if lhs_dtype != common_dtype: + self = astype(self, common_dtype) if reverse: tmp = self diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py index fd36f554028f88..4080c245845ce7 100644 --- a/python/paddle/base/type_promotion.py +++ b/python/paddle/base/type_promotion.py @@ -58,6 +58,10 @@ ] +SUPPORT_FLOAT = [dtype.float16, dtype.float32, dtype.float64, dtype.bfloat16] +SUPPORT_INT = [dtype.int32, dtype.int64] + + def get_result_dtype(x_dtype, y_dtype): if x_dtype == y_dtype: return x_dtype @@ -68,3 +72,11 @@ def get_result_dtype(x_dtype, y_dtype): raise TypeError( f"got unsupport dtype for type promotion: {x_dtype} and {y_dtype}." 
) + + +def is_support_float(dtype): + return dtype in SUPPORT_FLOAT + + +def is_support_int(dtype): + return dtype in SUPPORT_INT diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index b322ad8d485577..6eae66b3120389 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -277,5 +277,63 @@ def run_api(self): ) +class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic): + def set_dtype(self): + self.ldtype = 'float32' + self.rdtype = 'float64' + self.expected_out_dtype = 'bool' + + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value > self.r_value + out_reverse = self.r_value > self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool') +create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool') +create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'complex64', 'bool') +create_test_case( + TestOperatorOverloadGTInStatic, 'float16', 'complex128', 'bool' +) + +create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool') +create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'complex64', 'bool') +create_test_case( + TestOperatorOverloadGTInStatic, 'float32', 'complex128', 'bool' +) + +create_test_case(TestOperatorOverloadGTInStatic, 'float64', 'complex64', 'bool') +create_test_case( + TestOperatorOverloadGTInStatic, 'float64', 'complex128', 'bool' +) + +create_test_case( + TestOperatorOverloadGTInStatic, 'complex64', 'complex128', 'bool' +) + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool' + ) + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool' + ) + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool' + ) + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'complex64', 'bool' + ) + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'complex128', 'bool' + ) + + if __name__ == '__main__': unittest.main() From 83ec3e016e140f56eeaaf63965f72559bc36acdf Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 29 Nov 2023 11:56:50 +0000 Subject: [PATCH 13/27] update --- .../generator/eager_gen.py | 4 +- paddle/fluid/eager/type_promotion_utils.h | 90 +++++++------------ paddle/phi/common/type_promotion_table.h | 85 ++++++++++++++++++ 3 files changed, 117 insertions(+), 62 deletions(-) create mode 100644 paddle/phi/common/type_promotion_table.h diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index ccda5c8f34a746..5e5704bcef43e6 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -524,7 +524,7 @@ class {} : public egr::GradNodeBase {{ }} """ -PROMOTION_LOGIC_TEMPLATE = """ paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; +TYPE_PROMOTION_LOGIC_TEMPLATE = """ paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; if (egr::NeedTypePromotion(promote_tensors_vector)) {{ VLOG(5) << "got different data type, run type protmotion automatically."; {} @@ -1863,7 +1863,7 @@ def 
GenerateForwardDefinitionAndDeclaration(self, is_inplaced): + " " + " ".join(type_promote_optional_list) ) - type_promotion_logic_str = PROMOTION_LOGIC_TEMPLATE.format( + type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format( amp_tensors_vector_list_str, kernel_trans2_op_name_str, type_promote_vector_optional_list_str, diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index dd0633581037af..6abd0062dd7a66 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -12,39 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once -#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" + #include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/phi/common/type_promotion_table.h" namespace egr { -inline int DataTypeToNum(const phi::DataType& dtype) { - switch (dtype) { - case phi::DataType::UINT8: - return 0; - case phi::DataType::INT8: - return 1; - case phi::DataType::INT16: - return 2; - case phi::DataType::INT32: - return 3; - case phi::DataType::INT64: - return 4; - case phi::DataType::FLOAT16: - return 5; - case phi::DataType::FLOAT32: - return 6; - case phi::DataType::FLOAT64: - return 7; - case phi::DataType::COMPLEX64: - return 8; - case phi::DataType::COMPLEX128: - return 9; - case phi::DataType::BOOL: - return 10; - case phi::DataType::BFLOAT16: - return 11; - default: - PD_THROW("Invalid enum data type for type promote `", dtype, "`."); +inline paddle::Tensor PromoteCast(const std::string& input_name, + const paddle::Tensor& input, + const phi::DataType& dst_dtype, + bool trace_backward = true) { + if (input.dtype() != dst_dtype) { + return Cast(input, dst_dtype, trace_backward); + } else { + return input; } } @@ -65,37 +46,26 @@ static inline bool is_support_int(phi::DataType dtype) { } } -inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) { - constexpr auto u1 = phi::DataType::UINT8; - constexpr auto i1 = phi::DataType::INT8; - constexpr auto i2 = phi::DataType::INT16; - constexpr auto i4 = phi::DataType::INT32; - constexpr auto i8 = phi::DataType::INT64; - constexpr auto f2 = phi::DataType::FLOAT16; - constexpr auto f4 = phi::DataType::FLOAT32; - constexpr auto f8 = phi::DataType::FLOAT64; - constexpr auto c4 = phi::DataType::COMPLEX64; - constexpr auto c8 = phi::DataType::COMPLEX128; - constexpr auto b1 = phi::DataType::BOOL; - constexpr auto bf = phi::DataType::BFLOAT16; - - static constexpr phi::DataType _promoteTypesLookup[12][12] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, - /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, - }; +inline phi::DataType GetPromoteDtype( + const std::string& op_name, + const paddle::small_vector, + kSlotSmallVectorSize>& 
promote_tensors_vector) { + return phi::promoteTypes(promote_tensors_vector[0][0].dtype(), + promote_tensors_vector[1][0].dtype()); +} - return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; +inline bool NeedTypePromotion( + const paddle::small_vector, + kSlotSmallVectorSize>& promote_tensors_vector) { + // Tensor + Tensor only support type promotion between float, int32, int64 + phi::DataType a = promote_tensors_vector[0][0].dtype(); + phi::DataType b = promote_tensors_vector[1][0].dtype(); + if ((a != b) && (is_support_float(a) || is_support_int(a)) && + (is_support_float(b) || is_support_int(b))) { + return true; + } else { + return false; + } } -} // namespace egr \ No newline at end of file +} // namespace egr diff --git a/paddle/phi/common/type_promotion_table.h b/paddle/phi/common/type_promotion_table.h new file mode 100644 index 00000000000000..3d86d8636a208e --- /dev/null +++ b/paddle/phi/common/type_promotion_table.h @@ -0,0 +1,85 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/common/data_type.h" +namespace phi { + +inline int DataTypeToNum(const DataType& dtype) { + switch (dtype) { + case DataType::UINT8: + return 0; + case DataType::INT8: + return 1; + case DataType::INT16: + return 2; + case DataType::INT32: + return 3; + case DataType::INT64: + return 4; + case DataType::FLOAT16: + return 5; + case DataType::FLOAT32: + return 6; + case DataType::FLOAT64: + return 7; + case DataType::COMPLEX64: + return 8; + case DataType::COMPLEX128: + return 9; + case DataType::BOOL: + return 10; + case DataType::BFLOAT16: + return 11; + default: + PD_THROW("Invalid enum data type for type promote `", dtype, "`."); + } +} + +inline static DataType promoteTypes(DataType a, DataType b) { + constexpr auto u1 = DataType::UINT8; + constexpr auto i1 = DataType::INT8; + constexpr auto i2 = DataType::INT16; + constexpr auto i4 = DataType::INT32; + constexpr auto i8 = DataType::INT64; + constexpr auto f2 = DataType::FLOAT16; + constexpr auto f4 = DataType::FLOAT32; + constexpr auto f8 = DataType::FLOAT64; + constexpr auto c4 = DataType::COMPLEX64; + constexpr auto c8 = DataType::COMPLEX128; + constexpr auto b1 = DataType::BOOL; + constexpr auto bf = DataType::BFLOAT16; + + const int total_type_num = 12; + + static constexpr DataType + _promoteTypesLookup[total_type_num][total_type_num] = { + /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4}, + /* 
c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, + }; + return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; +} + +} // namespace phi From 2636832d949a39723997ec730fdce910eaf855e7 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 29 Nov 2023 12:06:14 +0000 Subject: [PATCH 14/27] fix note. --- paddle/fluid/eager/type_promotion_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 6abd0062dd7a66..4cdb23c3629eed 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -57,7 +57,7 @@ inline phi::DataType GetPromoteDtype( inline bool NeedTypePromotion( const paddle::small_vector, kSlotSmallVectorSize>& promote_tensors_vector) { - // Tensor + Tensor only support type promotion between float, int32, int64 + // Tensor + Tensor only support type promotion in float, int32, int64 phi::DataType a = promote_tensors_vector[0][0].dtype(); phi::DataType b = promote_tensors_vector[1][0].dtype(); if ((a != b) && (is_support_float(a) || is_support_int(a)) && From 7c4b08e592722bd8cbdab84317ff2296c33a61b4 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 30 Nov 2023 03:27:19 +0000 Subject: [PATCH 15/27] mv common logic to common dir. --- paddle/fluid/eager/type_promotion_utils.h | 39 ----------------------- paddle/phi/common/type_promotion_table.h | 39 +++++++++++++++++++++++ 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 4cdb23c3629eed..0d890b58bf50b9 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -29,43 +29,4 @@ inline paddle::Tensor PromoteCast(const std::string& input_name, } } -static inline bool is_support_float(phi::DataType dtype) { - if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::FLOAT32 || - dtype == phi::DataType::FLOAT64 || dtype == phi::DataType::BFLOAT16) { - return true; - } else { - return false; - } -} - -static inline bool is_support_int(phi::DataType dtype) { - if (dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64) { - return true; - } else { - return false; - } -} - -inline phi::DataType GetPromoteDtype( - const std::string& op_name, - const paddle::small_vector, - kSlotSmallVectorSize>& promote_tensors_vector) { - return phi::promoteTypes(promote_tensors_vector[0][0].dtype(), - promote_tensors_vector[1][0].dtype()); -} - -inline bool NeedTypePromotion( - const paddle::small_vector, - kSlotSmallVectorSize>& promote_tensors_vector) { - // Tensor + Tensor only support type promotion in float, int32, int64 - phi::DataType a = promote_tensors_vector[0][0].dtype(); - phi::DataType b = promote_tensors_vector[1][0].dtype(); - if ((a != b) && (is_support_float(a) || is_support_int(a)) && - (is_support_float(b) || is_support_int(b))) { - return true; - } else { - return false; - } -} - } // namespace egr diff --git a/paddle/phi/common/type_promotion_table.h b/paddle/phi/common/type_promotion_table.h index 3d86d8636a208e..a1fdde9c6f7e9f 100644 --- a/paddle/phi/common/type_promotion_table.h +++ b/paddle/phi/common/type_promotion_table.h @@ -82,4 +82,43 @@ inline static DataType promoteTypes(DataType a, DataType b) { return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; } 
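+// Example (illustrative only, read from the lookup table above):
+//   promoteTypes(DataType::FLOAT16, DataType::FLOAT32)  -> DataType::FLOAT32
+//   promoteTypes(DataType::BFLOAT16, DataType::FLOAT16) -> DataType::FLOAT32
+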
+static inline bool is_support_float(DataType dtype) { + if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 || + dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) { + return true; + } else { + return false; + } +} + +static inline bool is_support_int(phi::DataType dtype) { + if (dtype == DataType::INT32 || dtype == DataType::INT64) { + return true; + } else { + return false; + } +} + +inline phi::DataType GetPromoteDtype(const std::string& op_name, + const DataType x, + const DataType y) { + // future will deal this by different rule + if (op_name == "greater_than") { + // bool logic + return DataType::BOOL; + } else { + return phi::promoteTypes(x, y); + } +} + +inline bool NeedTypePromotion(const DataType x, const DataType y) { + // Tensor + Tensor only support type promotion in float, int32, int64 + if ((x != y) && (is_support_float(x) || is_support_int(x)) && + (is_support_float(y) || is_support_int(y))) { + return true; + } else { + return false; + } +} + } // namespace phi From 3a996f8ce998f408b6e039dc72ca2e750e8b503d Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 30 Nov 2023 03:28:34 +0000 Subject: [PATCH 16/27] fix --- paddle/phi/common/type_promotion_table.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/common/type_promotion_table.h b/paddle/phi/common/type_promotion_table.h index a1fdde9c6f7e9f..c7e412a2f5ffa2 100644 --- a/paddle/phi/common/type_promotion_table.h +++ b/paddle/phi/common/type_promotion_table.h @@ -47,7 +47,7 @@ inline int DataTypeToNum(const DataType& dtype) { } } -inline static DataType promoteTypes(DataType a, DataType b) { +inline static DataType promoteTypes(DataType x, DataType y) { constexpr auto u1 = DataType::UINT8; constexpr auto i1 = DataType::INT8; constexpr auto i2 = DataType::INT16; @@ -79,7 +79,7 @@ inline static DataType promoteTypes(DataType a, DataType b) { /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, }; - return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; + return _promoteTypesLookup[DataTypeToNum(x)][DataTypeToNum(y)]; } static inline bool is_support_float(DataType dtype) { From 6784ab7b3bc9f39324d3ced4d54404fa74e1de2a Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 30 Nov 2023 03:34:27 +0000 Subject: [PATCH 17/27] remove deal for int. --- paddle/phi/common/type_promotion_table.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/phi/common/type_promotion_table.h b/paddle/phi/common/type_promotion_table.h index c7e412a2f5ffa2..9a3ebe88f50775 100644 --- a/paddle/phi/common/type_promotion_table.h +++ b/paddle/phi/common/type_promotion_table.h @@ -112,9 +112,8 @@ inline phi::DataType GetPromoteDtype(const std::string& op_name, } inline bool NeedTypePromotion(const DataType x, const DataType y) { - // Tensor + Tensor only support type promotion in float, int32, int64 - if ((x != y) && (is_support_float(x) || is_support_int(x)) && - (is_support_float(y) || is_support_int(y))) { + // Tensor + Tensor only support type promotion for float type + if ((x != y) && is_support_float(x) && is_support_float(y)) { return true; } else { return false; From defb03599bf5eaa0a37e56c503d0ba180fd4b1a4 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 30 Nov 2023 04:07:44 +0000 Subject: [PATCH 18/27] remove int. 
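
Tensor + Tensor type promotion is limited to floating-point dtypes
(see PATCH 12 and PATCH 17), so the int32/int64 helper is dropped from
the C++ table header as well. A rough sketch of the intended
static-graph behaviour for this series; the variable names below are
illustrative only, and paddle.static.data / paddle.enable_static are
the standard static-graph APIs:

    import paddle

    paddle.enable_static()
    x = paddle.static.data('x', [3], 'float16')
    y = paddle.static.data('y', [3], 'float32')
    z = x + y   # float16 + float32 is promoted, result dtype is float32

    a = paddle.static.data('a', [3], 'int32')
    b = paddle.static.data('b', [3], 'int64')
    c = a + b   # integer inputs are not auto-promoted by this change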
--- paddle/phi/common/type_promotion_table.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/phi/common/type_promotion_table.h b/paddle/phi/common/type_promotion_table.h index 9a3ebe88f50775..fdb3f1e717faf2 100644 --- a/paddle/phi/common/type_promotion_table.h +++ b/paddle/phi/common/type_promotion_table.h @@ -91,14 +91,6 @@ static inline bool is_support_float(DataType dtype) { } } -static inline bool is_support_int(phi::DataType dtype) { - if (dtype == DataType::INT32 || dtype == DataType::INT64) { - return true; - } else { - return false; - } -} - inline phi::DataType GetPromoteDtype(const std::string& op_name, const DataType x, const DataType y) { From 49b4cf41bf275d1cfa5074c9731f5feb98034d00 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Thu, 30 Nov 2023 06:12:07 +0000 Subject: [PATCH 19/27] only for complie --- .../manual/eager_manual/forwards/multiply_fwd_func.cc | 9 +++++---- .../eager/auto_code_generator/generator/eager_gen.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc index 5542c5e51eac1a..8ab0a1b45669f3 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/api/include/sparse_api.h" +#include "paddle/phi/common/type_promotion_table.h" #include "paddle/phi/core/flags.h" PHI_DECLARE_bool(check_nan_inf); @@ -60,11 +61,11 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, // Type promotion Logic paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {{x}, {y}}; - if (egr::NeedTypePromotion(promote_tensors_vector)) { + if (phi::NeedTypePromotion(x.dtype(), y.dtype())) { VLOG(5) << "got different data type, run type protmotion automatically."; auto op_name = phi::TransToFluidOpName("add"); - auto promotion_type = egr::GetPromoteDtype(op_name, promote_tensors_vector); + auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype()); auto new_x = egr::PromoteCast("x", x, promotion_type); auto new_y = egr::PromoteCast("y", y, promotion_type); @@ -407,11 +408,11 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, // Type promotion Logic paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {{x}, {y}}; - if (egr::NeedTypePromotion(promote_tensors_vector)) { + if (phi::NeedTypePromotion(x.dtype(), y.dtype())) { VLOG(5) << "got different data type, run type protmotion automatically."; auto op_name = phi::TransToFluidOpName("add"); - auto promotion_type = egr::GetPromoteDtype(op_name, promote_tensors_vector); + auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype()); auto new_x = egr::PromoteCast("x", x, promotion_type); auto new_y = egr::PromoteCast("y", y, promotion_type); diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 5e5704bcef43e6..52d1b02fc18cc6 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -459,6 +459,7 @@ class {} : public egr::GradNodeBase {{ #include "paddle/phi/core/flags.h" #include "paddle/phi/api/lib/data_transform.h" #include 
"paddle/fluid/eager/type_promotion_utils.h" +#include "paddle/phi/common/type_promotion_table.h" PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_string(tensor_operants_mode); {} @@ -525,7 +526,7 @@ class {} : public egr::GradNodeBase {{ """ TYPE_PROMOTION_LOGIC_TEMPLATE = """ paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; - if (egr::NeedTypePromotion(promote_tensors_vector)) {{ + if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {{ VLOG(5) << "got different data type, run type protmotion automatically."; {} {} @@ -1854,7 +1855,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): ) # Forward type promotion logic if forward_api_name in type_promote_white_list: - type_promote_get_dst_dtype_str = "auto promotion_type = egr::GetPromoteDtype(op_name, promote_tensors_vector);\n" + type_promote_get_dst_dtype_str = "auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(),y.dtype());\n" type_promote_vector_optional_list_str = " ".join( type_promote_vector_optional_list ) From 0f0f7b19645657a49306c0032fe0f9342d491fe2 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 30 Nov 2023 07:36:40 +0000 Subject: [PATCH 20/27] fix median / cross_entropy_loss --- python/paddle/base/layers/math_op_patch.py | 33 ++++++++++++++----- python/paddle/base/type_promotion.py | 6 ++++ python/paddle/nn/functional/loss.py | 12 +++++-- python/paddle/tensor/stat.py | 4 ++- .../legacy_test/test_tensor_type_promotion.py | 23 ------------- 5 files changed, 43 insertions(+), 35 deletions(-) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 589d29799f90d9..caa91c1f4f597d 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -32,6 +32,15 @@ compare_ops = ['__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__'] +SUPPORT_PROMOTION_OPS = [ + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", +] + EXPRESSION_MAP = { "__add__": "A + B", "__radd__": "A += B", @@ -523,14 +532,22 @@ def __impl__(self, other_var): rhs_dtype = safe_get_dtype(other_var) if lhs_dtype != rhs_dtype: - from ..type_promotion import get_result_dtype, is_support_float - - if is_support_float(lhs_dtype) and is_support_float(rhs_dtype): - common_dtype = get_result_dtype(lhs_dtype, rhs_dtype) - if rhs_dtype != common_dtype: - other_var = astype(other_var, common_dtype) - if lhs_dtype != common_dtype: - self = astype(self, common_dtype) + if method_name in SUPPORT_PROMOTION_OPS: + from ..type_promotion import ( + get_result_dtype, + is_support_float_and_complex, + ) + + if is_support_float_and_complex( + lhs_dtype + ) and is_support_float_and_complex(rhs_dtype): + common_dtype = get_result_dtype(lhs_dtype, rhs_dtype) + if rhs_dtype != common_dtype: + other_var = astype(other_var, common_dtype) + if lhs_dtype != common_dtype: + self = astype(self, common_dtype) + else: + other_var = astype(other_var, lhs_dtype) if reverse: tmp = self diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py index 4080c245845ce7..2fc5ba71cf80ca 100644 --- a/python/paddle/base/type_promotion.py +++ b/python/paddle/base/type_promotion.py @@ -60,6 +60,8 @@ SUPPORT_FLOAT = [dtype.float16, dtype.float32, dtype.float64, dtype.bfloat16] SUPPORT_INT = [dtype.int32, dtype.int64] +SUPPORT_COMPLEX = [dtype.complex64, dtype.complex128] +SUPPORT_FLOAT_AND_COMPLEX = SUPPORT_FLOAT + SUPPORT_COMPLEX def get_result_dtype(x_dtype, y_dtype): @@ -78,5 +80,9 @@ def 
is_support_float(dtype): return dtype in SUPPORT_FLOAT +def is_support_float_and_complex(dtype): + return dtype in SUPPORT_FLOAT_AND_COMPLEX + + def is_support_int(dtype): return dtype in SUPPORT_INT diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index f16115e66084e0..21712fbd9014cc 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3047,20 +3047,26 @@ def cross_entropy( if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = paddle.sum(mask, name=name) - ret = out_sum / (count + paddle.equal(count, 0.0)) + ret = out_sum / ( + count + paddle.equal(count, 0.0).astype(count.dtype) + ) else: mask = paddle.cast(mask, weight_gather_reshape.dtype) weight_ignored = paddle.multiply( mask, weight_gather_reshape ) weight_sum = paddle.sum(weight_ignored, name=name) - ret = out_sum / (weight_sum + paddle.equal(weight_sum, 0.0)) + ret = out_sum / ( + weight_sum + + paddle.equal(weight_sum, 0.0).astype(weight_sum.dtype) + ) return ret elif weight is not None: out_sum = paddle.sum(out, name=name) total_weight = paddle.sum(weight_gather_reshape) return out_sum / ( - total_weight + paddle.equal(total_weight, 0.0) + total_weight + + paddle.equal(total_weight, 0.0).astype(total_weight.dtype) ) else: return paddle.mean(out, name=name) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index d7bcc48c8fa451..2cc473d1e3b318 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -449,7 +449,9 @@ def median(x, axis=None, keepdim=False, name=None): dtype=dtype, ) out_tensor = out_tensor + paddle.sum( - paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True + paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype), + axis=axis, + keepdim=True, ) if is_flatten: if keepdim: diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index 6eae66b3120389..53fec9ceb42fb6 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -297,25 +297,8 @@ def run_api(self): create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool') create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool') -create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'complex64', 'bool') -create_test_case( - TestOperatorOverloadGTInStatic, 'float16', 'complex128', 'bool' -) create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool') -create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'complex64', 'bool') -create_test_case( - TestOperatorOverloadGTInStatic, 'float32', 'complex128', 'bool' -) - -create_test_case(TestOperatorOverloadGTInStatic, 'float64', 'complex64', 'bool') -create_test_case( - TestOperatorOverloadGTInStatic, 'float64', 'complex128', 'bool' -) - -create_test_case( - TestOperatorOverloadGTInStatic, 'complex64', 'complex128', 'bool' -) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( @@ -327,12 +310,6 @@ def run_api(self): create_test_case( TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool' ) - create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'complex64', 'bool' - ) - create_test_case( - TestOperatorOverloadGTInStatic, 'bfloat16', 'complex128', 'bool' - ) if __name__ == '__main__': From bfc51fdaa500d033640e46ddc44d3c8671b047a8 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Thu, 30 Nov 2023 09:19:03 +0000 Subject: [PATCH 21/27] 
keep old illogical logic for compatibility reasons --- python/paddle/base/layers/math_op_patch.py | 5 +++++ python/paddle/nn/functional/loss.py | 12 +++--------- python/paddle/tensor/stat.py | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index caa91c1f4f597d..9e76a93b4dd919 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -546,6 +546,11 @@ def __impl__(self, other_var): other_var = astype(other_var, common_dtype) if lhs_dtype != common_dtype: self = astype(self, common_dtype) + else: + # NOTE(zoooo0820): Currently, we still keep the old illogical \ + # logic for compatibility reasons + other_var = astype(other_var, lhs_dtype) + else: other_var = astype(other_var, lhs_dtype) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 21712fbd9014cc..f16115e66084e0 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -3047,26 +3047,20 @@ def cross_entropy( if weight is None: mask = paddle.cast(mask, dtype=out_sum.dtype) count = paddle.sum(mask, name=name) - ret = out_sum / ( - count + paddle.equal(count, 0.0).astype(count.dtype) - ) + ret = out_sum / (count + paddle.equal(count, 0.0)) else: mask = paddle.cast(mask, weight_gather_reshape.dtype) weight_ignored = paddle.multiply( mask, weight_gather_reshape ) weight_sum = paddle.sum(weight_ignored, name=name) - ret = out_sum / ( - weight_sum - + paddle.equal(weight_sum, 0.0).astype(weight_sum.dtype) - ) + ret = out_sum / (weight_sum + paddle.equal(weight_sum, 0.0)) return ret elif weight is not None: out_sum = paddle.sum(out, name=name) total_weight = paddle.sum(weight_gather_reshape) return out_sum / ( - total_weight - + paddle.equal(total_weight, 0.0).astype(total_weight.dtype) + total_weight + paddle.equal(total_weight, 0.0) ) else: return paddle.mean(out, name=name) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 2cc473d1e3b318..41f75f5230e785 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -449,7 +449,7 @@ def median(x, axis=None, keepdim=False, name=None): dtype=dtype, ) out_tensor = out_tensor + paddle.sum( - paddle.cast(paddle.isnan(x), dtype=dtype) * x.astype(dtype), + paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True, ) From 8f060a6dfeeb378d73748510eb51abe9e1f750f0 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 1 Dec 2023 03:52:07 +0000 Subject: [PATCH 22/27] pybind the type_promotion function; remove python function; remove float-complex test --- paddle/fluid/pybind/pybind.cc | 17 ++++ python/paddle/base/__init__.py | 1 - python/paddle/base/core.py | 3 + python/paddle/base/layers/math_op_patch.py | 13 +-- python/paddle/base/type_promotion.py | 88 ------------------- .../legacy_test/test_tensor_type_promotion.py | 85 ------------------ 6 files changed, 24 insertions(+), 183 deletions(-) delete mode 100644 python/paddle/base/type_promotion.py diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 0674edb09185dd..e155b529c2ad3e 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -204,6 +204,7 @@ limitations under the License. 
*/ #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" +#include "paddle/phi/common/type_promotion_table.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" @@ -883,6 +884,22 @@ PYBIND11_MODULE(libpaddle, m) { &paddle::prim::PrimCommonUtils::SetTargetGradName); m.def("set_num_threads", &platform::SetNumThreads); + m.def("need_type_promotion", + [](framework::proto::VarType::Type type_x, + framework::proto::VarType::Type type_y) { + return phi::NeedTypePromotion(framework::TransToPhiDataType(type_x), + framework::TransToPhiDataType(type_y)); + }); + m.def("get_promote_dtype", + [](const std::string &op_name, + framework::proto::VarType::Type type_x, + framework::proto::VarType::Type type_y) { + return framework::TransToProtoVarType( + phi::GetPromoteDtype(op_name, + framework::TransToPhiDataType(type_x), + framework::TransToPhiDataType(type_y))); + }); + m.def("disable_signal_handler", &DisableSignalHandler); m.def("clear_gradients", diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 19342a4ca1f8f6..7e5ac9c1d92c44 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -126,7 +126,6 @@ HeterXpuTrainer, ) from .backward import append_backward -from . import type_promotion Tensor = LoDTensor enable_imperative = enable_dygraph diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py index c491d90c43a919..12a719765d6727 100644 --- a/python/paddle/base/core.py +++ b/python/paddle/base/core.py @@ -342,6 +342,9 @@ def to_list(s): _set_prim_target_grad_name, ) + # type promotion + from .libpaddle import need_type_promotion, get_promote_dtype # noqa: F401 + # isort: on if sys.platform != 'win32': from .libpaddle import ( # noqa: F401 diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 9e76a93b4dd919..42c5ff743939d0 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -533,15 +533,10 @@ def __impl__(self, other_var): if lhs_dtype != rhs_dtype: if method_name in SUPPORT_PROMOTION_OPS: - from ..type_promotion import ( - get_result_dtype, - is_support_float_and_complex, - ) - - if is_support_float_and_complex( - lhs_dtype - ) and is_support_float_and_complex(rhs_dtype): - common_dtype = get_result_dtype(lhs_dtype, rhs_dtype) + if core.need_type_promotion(lhs_dtype, rhs_dtype): + common_dtype = core.get_promote_dtype( + op_type, lhs_dtype, rhs_dtype + ) if rhs_dtype != common_dtype: other_var = astype(other_var, common_dtype) if lhs_dtype != common_dtype: diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py deleted file mode 100644 index 2fc5ba71cf80ca..00000000000000 --- a/python/paddle/base/type_promotion.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -from paddle.framework import dtype - -u1 = dtype.uint8 -i1 = dtype.int8 -i2 = dtype.int16 -i4 = dtype.int32 -i8 = dtype.int64 -f2 = dtype.float16 -f4 = dtype.float32 -f8 = dtype.float64 -c4 = dtype.complex64 -c8 = dtype.complex128 -b1 = dtype.bool -bf = dtype.bfloat16 - - -Number = { - dtype.uint8: 0, - dtype.int8: 1, - dtype.int16: 2, - dtype.int32: 3, - dtype.int64: 4, - dtype.float16: 5, - dtype.float32: 6, - dtype.float64: 7, - dtype.complex64: 8, - dtype.complex128: 9, - dtype.bool: 10, - dtype.bfloat16: 11, -} - -promoteTypesLookup = [ - [u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf], - [i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf], - [i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf], - [i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf], - [i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf], - [f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4], - [f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4], - [f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8], - [c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4], - [c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8], - [u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf], - [bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf], -] - - -SUPPORT_FLOAT = [dtype.float16, dtype.float32, dtype.float64, dtype.bfloat16] -SUPPORT_INT = [dtype.int32, dtype.int64] -SUPPORT_COMPLEX = [dtype.complex64, dtype.complex128] -SUPPORT_FLOAT_AND_COMPLEX = SUPPORT_FLOAT + SUPPORT_COMPLEX - - -def get_result_dtype(x_dtype, y_dtype): - if x_dtype == y_dtype: - return x_dtype - else: - try: - return promoteTypesLookup[Number[x_dtype]][Number[y_dtype]] - except: - raise TypeError( - f"got unsupport dtype for type promotion: {x_dtype} and {y_dtype}." 
- ) - - -def is_support_float(dtype): - return dtype in SUPPORT_FLOAT - - -def is_support_float_and_complex(dtype): - return dtype in SUPPORT_FLOAT_AND_COMPLEX - - -def is_support_int(dtype): - return dtype in SUPPORT_INT diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index 53fec9ceb42fb6..c47bfe8e5d1d56 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -101,33 +101,11 @@ def test_dtype_is_expected(self): create_test_case( TestOperatorOverloadAddInStatic, 'float16', 'float64', 'float64' ) -create_test_case( - TestOperatorOverloadAddInStatic, 'float16', 'complex64', 'complex64' -) -create_test_case( - TestOperatorOverloadAddInStatic, 'float16', 'complex128', 'complex128' -) create_test_case( TestOperatorOverloadAddInStatic, 'float32', 'float64', 'float64' ) -create_test_case( - TestOperatorOverloadAddInStatic, 'float32', 'complex64', 'complex64' -) -create_test_case( - TestOperatorOverloadAddInStatic, 'float32', 'complex128', 'complex128' -) -create_test_case( - TestOperatorOverloadAddInStatic, 'float64', 'complex64', 'complex128' -) -create_test_case( - TestOperatorOverloadAddInStatic, 'float64', 'complex128', 'complex128' -) - -create_test_case( - TestOperatorOverloadAddInStatic, 'complex64', 'complex128', 'complex128' -) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( @@ -139,12 +117,6 @@ def test_dtype_is_expected(self): create_test_case( TestOperatorOverloadAddInStatic, 'bfloat16', 'float64', 'float64' ) - create_test_case( - TestOperatorOverloadAddInStatic, 'bfloat16', 'complex64', 'complex64' - ) - create_test_case( - TestOperatorOverloadAddInStatic, 'bfloat16', 'complex128', 'complex128' - ) class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic): @@ -166,33 +138,11 @@ def run_api(self): create_test_case( TestOperatorOverloadSubInStatic, 'float16', 'float64', 'float64' ) -create_test_case( - TestOperatorOverloadSubInStatic, 'float16', 'complex64', 'complex64' -) -create_test_case( - TestOperatorOverloadSubInStatic, 'float16', 'complex128', 'complex128' -) create_test_case( TestOperatorOverloadSubInStatic, 'float32', 'float64', 'float64' ) -create_test_case( - TestOperatorOverloadSubInStatic, 'float32', 'complex64', 'complex64' -) -create_test_case( - TestOperatorOverloadSubInStatic, 'float32', 'complex128', 'complex128' -) - -create_test_case( - TestOperatorOverloadSubInStatic, 'float64', 'complex64', 'complex128' -) -create_test_case( - TestOperatorOverloadSubInStatic, 'float64', 'complex128', 'complex128' -) -create_test_case( - TestOperatorOverloadSubInStatic, 'complex64', 'complex128', 'complex128' -) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( @@ -204,12 +154,6 @@ def run_api(self): create_test_case( TestOperatorOverloadSubInStatic, 'bfloat16', 'float64', 'float64' ) - create_test_case( - TestOperatorOverloadSubInStatic, 'bfloat16', 'complex64', 'complex64' - ) - create_test_case( - TestOperatorOverloadSubInStatic, 'bfloat16', 'complex128', 'complex128' - ) class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic): @@ -231,33 +175,10 @@ def run_api(self): create_test_case( TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64' ) -create_test_case( - TestOperatorOverloadMulInStatic, 'float16', 'complex64', 'complex64' -) -create_test_case( - TestOperatorOverloadMulInStatic, 'float16', 'complex128', 'complex128' -) 
create_test_case( TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64' ) -create_test_case( - TestOperatorOverloadMulInStatic, 'float32', 'complex64', 'complex64' -) -create_test_case( - TestOperatorOverloadMulInStatic, 'float32', 'complex128', 'complex128' -) - -create_test_case( - TestOperatorOverloadMulInStatic, 'float64', 'complex64', 'complex128' -) -create_test_case( - TestOperatorOverloadMulInStatic, 'float64', 'complex128', 'complex128' -) - -create_test_case( - TestOperatorOverloadMulInStatic, 'complex64', 'complex128', 'complex128' -) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( @@ -269,12 +190,6 @@ def run_api(self): create_test_case( TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64' ) - create_test_case( - TestOperatorOverloadMulInStatic, 'bfloat16', 'complex64', 'complex64' - ) - create_test_case( - TestOperatorOverloadMulInStatic, 'bfloat16', 'complex128', 'complex128' - ) class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic): From b893736d14c3ddbf27b1c402f1f0f603a0f0b849 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 1 Dec 2023 06:16:48 +0000 Subject: [PATCH 23/27] remove change for dygraph --- .../forwards/multiply_fwd_func.cc | 32 ----- .../generator/eager_gen.py | 84 +---------- paddle/fluid/pybind/eager_math_op_patch.cc | 133 +++++++++++++++++- python/paddle/tensor/stat.py | 4 +- .../test_math_op_patch_var_base.py | 82 +---------- 5 files changed, 131 insertions(+), 204 deletions(-) diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc index 8ab0a1b45669f3..092620120cae19 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc @@ -19,11 +19,9 @@ #include "paddle/fluid/eager/eager_amp_auto_cast.h" #include "paddle/fluid/eager/eager_layout_auto_tune.h" #include "paddle/fluid/eager/nan_inf_utils.h" -#include "paddle/fluid/eager/type_promotion_utils.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/api/include/sparse_api.h" -#include "paddle/phi/common/type_promotion_table.h" #include "paddle/phi/core/flags.h" PHI_DECLARE_bool(check_nan_inf); @@ -58,21 +56,6 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } } - // Type promotion Logic - paddle::small_vector, egr::kSlotSmallVectorSize> - promote_tensors_vector = {{x}, {y}}; - if (phi::NeedTypePromotion(x.dtype(), y.dtype())) { - VLOG(5) << "got different data type, run type protmotion automatically."; - auto op_name = phi::TransToFluidOpName("add"); - - auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype()); - - auto new_x = egr::PromoteCast("x", x, promotion_type); - auto new_y = egr::PromoteCast("y", y, promotion_type); - - return multiply_ad_func(new_x, new_y); - } - // Layout autotune if (egr::Controller::Instance().UseLayoutAutoTune()) { @@ -405,21 +388,6 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } } - // Type promotion Logic - paddle::small_vector, egr::kSlotSmallVectorSize> - promote_tensors_vector = {{x}, {y}}; - if (phi::NeedTypePromotion(x.dtype(), y.dtype())) { - VLOG(5) << "got different data type, run type protmotion automatically."; - auto op_name = phi::TransToFluidOpName("add"); - - auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype()); - - auto new_x = 
egr::PromoteCast("x", x, promotion_type); - auto new_y = egr::PromoteCast("y", y, promotion_type); - - return multiply_ad_func(new_x, new_y); - } - // Layout autotune if (egr::Controller::Instance().UseLayoutAutoTune()) { diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 52d1b02fc18cc6..ff1758e3ef93a4 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -75,13 +75,6 @@ "tanh_triple_grad", ] -# white ops list whose kernel can automaically do type promotion. -type_promote_white_list = [ - "add", - "subtract", - "greater_than", -] - # dict of special api that forward api's output will affect bacward api's output # bacward api's output usually affected by backward api's input special_prune_dict = { @@ -254,8 +247,6 @@ class {} : public egr::GradNodeBase {{ // Dygraph Record Event {} // AMP Logic -{} - // Type promotion Logic {} // Layout autotune {} @@ -324,8 +315,6 @@ class {} : public egr::GradNodeBase {{ // Dygraph Record Event {} // AMP Logic -{} - // Type promotion Logic {} // Layout autotune {} @@ -458,8 +447,7 @@ class {} : public egr::GradNodeBase {{ #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/api/lib/data_transform.h" -#include "paddle/fluid/eager/type_promotion_utils.h" -#include "paddle/phi/common/type_promotion_table.h" + PHI_DECLARE_bool(check_nan_inf); PHI_DECLARE_string(tensor_operants_mode); {} @@ -524,19 +512,6 @@ class {} : public egr::GradNodeBase {{ }} }} """ - -TYPE_PROMOTION_LOGIC_TEMPLATE = """ paddle::small_vector, egr::kSlotSmallVectorSize> promote_tensors_vector = {}; - if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {{ - VLOG(5) << "got different data type, run type protmotion automatically."; - {} - {} - {} - {} - {} - }} -""" - - LAYOUT_LOGIC_TEMPLATE = """ if (egr::Controller::Instance().UseLayoutAutoTune()) {{ paddle::small_vector, egr::kSlotSmallVectorSize> tensors_vector = {}; @@ -1488,10 +1463,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_tensors_vector_optional_list = [] amp_autocast_list = [] amp_autocast_optional_list = [] - type_promote_vector_list = [] - type_promote_vector_optional_list = [] - type_promote_list = [] - type_promote_optional_list = [] layout_autotune_list = [] layout_autotune_optional_list = [] layout_tensors_vector_optional_list = [] @@ -1518,12 +1489,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_optional_list.append( f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) - type_promote_vector_optional_list.append( - f"if ({name}) promote_tensors_vector.push_back({{ *{name} }});\n" - ) - type_promote_optional_list.append( - f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" - ) layout_tensors_vector_optional_list.append( f"if ({name}) tensors_vector.push_back({{ *{name} }});\n" ) @@ -1547,10 +1512,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_list.append( f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) - type_promote_vector_list.append(f"{name}") - type_promote_list.append( - f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" - ) layout_autotune_list.append( f"auto new_{name} = transformer->TransInTensor(\"{name}\", {name});\n" ) @@ 
-1572,12 +1533,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_optional_list.append( f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) - type_promote_vector_optional_list.append( - f"if ({name}) promote_tensors_vector.push_back( *{name} );\n" - ) - type_promote_optional_list.append( - f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" - ) layout_autotune_optional_list.append( f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n" ) @@ -1594,10 +1549,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_list.append( f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n" ) - type_promote_vector_list.append(f"{name}") - type_promote_list.append( - f"auto new_{name} = egr::PromoteCast(\"{name}\", {name}, promotion_type);\n" - ) layout_autotune_list.append( f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n" ) @@ -1853,31 +1804,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_autocast_list_str, amp_call_str, ) - # Forward type promotion logic - if forward_api_name in type_promote_white_list: - type_promote_get_dst_dtype_str = "auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(),y.dtype());\n" - type_promote_vector_optional_list_str = " ".join( - type_promote_vector_optional_list - ) - type_promote_list_str = ( - " ".join(type_promote_list) - + " " - + " ".join(type_promote_optional_list) - ) - type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format( - amp_tensors_vector_list_str, - kernel_trans2_op_name_str, - type_promote_vector_optional_list_str, - type_promote_get_dst_dtype_str, - type_promote_list_str, - amp_call_str, - ) - else: - type_promotion_logic_str = ( - "\n VLOG(5) << \" No Promotion for {} api. \"; ".format( - forward_ad_function_name - ) - ) + # Forward layout autotune layout_autotune_list_str = " ".join( layout_autotune_list @@ -1914,11 +1841,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_logic_str = "\n VLOG(7) << \" No AMP for {} because it has no input. \"; ".format( forward_ad_function_name ) - type_promotion_logic_str = ( - "\n VLOG(7) << \" No Promotion for {} api. \"; ".format( - forward_ad_function_name - ) - ) self.forward_definition_str += ( FORWARD_ONLY_FUNCTION_TEMPLATE.format( returns_type_str, @@ -1927,7 +1849,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): forward_api_name, dygraph_event_str, amp_logic_str, - type_promotion_logic_str, layout_logic_str, forward_api_name, before_log_str, @@ -1950,7 +1871,6 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): forward_api_name, dygraph_event_str, amp_logic_str, - type_promotion_logic_str, layout_logic_str, inputs_autograd_meta_str, forward_api_name, diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 524d08eae2fe20..aa7a27db207364 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -252,7 +252,36 @@ static PyObject* tensor__add__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types move to add_ad_func + // 3. 
promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + // note: only op_type in _supported_promote_complex_types_ should promote + // dtype + if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || + _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { + phi::DataType promote_dtype = + framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( + framework::TransToProtoVarType(lhs_dtype), + framework::TransToProtoVarType(rhs_dtype))); + if (lhs_dtype != promote_dtype) { + // cast + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, promote_dtype); + } + if (rhs_dtype != promote_dtype) { + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, promote_dtype); + } + } else { + VLOG(6) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + } // 4. calculation VLOG(6) << "Calling add_ad_func in tensor__add__method"; @@ -329,8 +358,34 @@ static PyObject* tensor__sub__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types move to subtract_ad_func - + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || + _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { + phi::DataType promote_dtype = + framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( + framework::TransToProtoVarType(lhs_dtype), + framework::TransToProtoVarType(rhs_dtype))); + if (lhs_dtype != promote_dtype) { + // cast + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, promote_dtype); + } + if (rhs_dtype != promote_dtype) { + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, promote_dtype); + } + } else { + VLOG(6) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + } // 4. calculation VLOG(6) << "Calling subtract_ad_func in tensor__sub__method"; { @@ -405,7 +460,34 @@ static PyObject* tensor__rsub__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types move to subtract_ad_func + // 3. 
promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || + _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { + phi::DataType promote_dtype = + framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( + framework::TransToProtoVarType(lhs_dtype), + framework::TransToProtoVarType(rhs_dtype))); + if (lhs_dtype != promote_dtype) { + // cast + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, promote_dtype); + } + if (rhs_dtype != promote_dtype) { + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, promote_dtype); + } + } else { + VLOG(6) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + } // 4. calculation VLOG(6) << "Calling subtract_ad_func in tensor__rsub__method"; @@ -486,7 +568,36 @@ static PyObject* tensor__mul__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types move to multiply_ad_func + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + // note: only op_type in _supported_promote_complex_types_ should promote + // dtype + if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() || + _complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) { + phi::DataType promote_dtype = + framework::TransToPhiDataType(framework::PromoteTypesIfComplexExists( + framework::TransToProtoVarType(lhs_dtype), + framework::TransToProtoVarType(rhs_dtype))); + if (lhs_dtype != promote_dtype) { + // cast + eager_gil_scoped_release guard; + self_tensor = cast_ad_func(self_tensor, promote_dtype); + } + if (rhs_dtype != promote_dtype) { + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, promote_dtype); + } + } else { + VLOG(6) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } + } // 4. calculation VLOG(6) << "Calling multiply_ad_func in tensor__mul__method"; @@ -816,7 +927,17 @@ static PyObject* tensor__gt__method(TensorObject* self, ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor); } - // 3. promote types move to greater_than_ad_func + // 3. promote types or unify right var type to left var + phi::DataType lhs_dtype = self_tensor.dtype(); + phi::DataType rhs_dtype = other_tensor.dtype(); + if (lhs_dtype != rhs_dtype) { + VLOG(6) << "The dtype of left and right Tensor are not the same, left " + "dtype is " + << lhs_dtype << ", but right dtype is " << rhs_dtype + << ", the right dtype will convert to " << lhs_dtype; + eager_gil_scoped_release guard; + other_tensor = cast_ad_func(other_tensor, lhs_dtype); + } // 4. 
calculation VLOG(6) << "Calling greater_than_ad_func in tensor__gt__method"; diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 41f75f5230e785..d7bcc48c8fa451 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -449,9 +449,7 @@ def median(x, axis=None, keepdim=False, name=None): dtype=dtype, ) out_tensor = out_tensor + paddle.sum( - paddle.cast(paddle.isnan(x), dtype=dtype) * x, - axis=axis, - keepdim=True, + paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True ) if is_flatten: if keepdim: diff --git a/test/legacy_test/test_math_op_patch_var_base.py b/test/legacy_test/test_math_op_patch_var_base.py index f7a63a0ee3d91f..af5fbd9ba9ca1c 100644 --- a/test/legacy_test/test_math_op_patch_var_base.py +++ b/test/legacy_test/test_math_op_patch_var_base.py @@ -35,28 +35,6 @@ def test_add(self): res = a + b np.testing.assert_array_equal(res.numpy(), a_np + b_np) - def test_type_promotion_add_F_F(self): - a_np = np.random.random(self.shape).astype(np.float32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a + b - res_t = b + a - np.testing.assert_array_equal(res_t.numpy(), res.numpy()) - np.testing.assert_array_equal(res.numpy(), a_np + b_np) - - def test_type_promotion_add_F_I(self): - a_np = np.random.random(self.shape).astype(np.int32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a + b - res_t = b + a - np.testing.assert_array_equal(res_t.numpy(), res.numpy()) - np.testing.assert_array_equal(res.numpy(), a_np + b_np) - def test_sub(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) @@ -65,25 +43,7 @@ def test_sub(self): b = base.dygraph.to_variable(b_np) res = a - b np.testing.assert_array_equal(res.numpy(), a_np - b_np) - - def test_type_promotion_sub_F_F(self): - a_np = np.random.random(self.shape).astype(np.float32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a - b - np.testing.assert_array_equal(res.numpy(), a_np - b_np) - - def test_type_promotion_sub_F_I(self): - a_np = np.random.random(self.shape).astype(np.int32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a - b - np.testing.assert_array_equal(res.numpy(), a_np - b_np) - + def test_mul(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) @@ -93,28 +53,6 @@ def test_mul(self): res = a * b np.testing.assert_array_equal(res.numpy(), a_np * b_np) - def test_type_promotion_mul_F_F(self): - a_np = np.random.random(self.shape).astype(np.float32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a * b - res_t = b * a - np.testing.assert_array_equal(res_t.numpy(), res.numpy()) - np.testing.assert_array_equal(res.numpy(), a_np * b_np) - - def test_type_promotion_mul_F_I(self): - a_np = np.random.random(self.shape).astype(np.int32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) 
- b = base.dygraph.to_variable(b_np) - res = a * b - res_t = b * a - np.testing.assert_array_equal(res_t.numpy(), res.numpy()) - np.testing.assert_array_equal(res.numpy(), a_np * b_np) - def test_div(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) @@ -281,24 +219,6 @@ def test_greater_than(self): res = a > b np.testing.assert_array_equal(res.numpy(), a_np > b_np) - def test_type_promotion_greater_than_F_F(self): - a_np = np.random.random(self.shape).astype(np.float32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a > b - np.testing.assert_array_equal(res.numpy(), a_np > b_np) - - def test_type_promotion_greater_than_F_I(self): - a_np = np.random.random(self.shape).astype(np.int32) - b_np = np.random.random(self.shape).astype(np.float16) - with base.dygraph.guard(): - a = base.dygraph.to_variable(a_np) - b = base.dygraph.to_variable(b_np) - res = a > b - np.testing.assert_array_equal(res.numpy(), a_np > b_np) - def test_greater_equal(self): a_np = np.random.random(self.shape).astype(self.dtype) b_np = np.random.random(self.shape).astype(self.dtype) From a06b28cf65c0d89c85cfa57c365d619729b6de52 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Fri, 1 Dec 2023 06:36:46 +0000 Subject: [PATCH 24/27] rename type_promotion_table.h -> data_type_promotion.h --- paddle/fluid/eager/type_promotion_utils.h | 2 +- paddle/fluid/pybind/pybind.cc | 2 +- .../common/{type_promotion_table.h => data_type_promotion.h} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename paddle/phi/common/{type_promotion_table.h => data_type_promotion.h} (100%) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 0d890b58bf50b9..8504f8b16c6a81 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -14,7 +14,7 @@ #pragma once #include "paddle/fluid/eager/api/utils/global_utils.h" -#include "paddle/phi/common/type_promotion_table.h" +#include "paddle/phi/common/data_type_promotion.h" namespace egr { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index e155b529c2ad3e..de889deade5786 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -204,7 +204,7 @@ limitations under the License. 
*/ #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" -#include "paddle/phi/common/type_promotion_table.h" +#include "paddle/phi/common/data_type_promotion.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" diff --git a/paddle/phi/common/type_promotion_table.h b/paddle/phi/common/data_type_promotion.h similarity index 100% rename from paddle/phi/common/type_promotion_table.h rename to paddle/phi/common/data_type_promotion.h From fbb47040685e61e6b0810597424c4c7bf43960fb Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Mon, 4 Dec 2023 10:03:43 +0000 Subject: [PATCH 25/27] convert dtype in Block.append_op; support where op --- python/paddle/base/framework.py | 49 ++++++++- python/paddle/base/layers/math_op_patch.py | 22 ++-- test/legacy_test/test_where_op.py | 114 ++++++++++++++++++++- 3 files changed, 170 insertions(+), 15 deletions(-) diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index 6723f5e0f40f96..f7eac26e282b24 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -56,6 +56,14 @@ CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() _global_flags_ = core.globals() +# TODO(zoooo0820): unify this dict of dygraph and static at Pybind +SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { + "elementwise_add": ['X', 'Y'], + "elementwise_sub": ['X', 'Y'], + "elementwise_mul": ['X', 'Y'], + "where": ['X', 'Y'], +} + def _global_flags(): return _global_flags_ @@ -4383,6 +4391,38 @@ def _is_inited_by(block, var): param.stop_gradient = stop_gradient return param + def _type_promotion_for_inputs(self, op_type, inputs): + need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get( + op_type, None + ) + if need_transed_var_names is None: + return + + all_dtypes = [] + for input_name in inputs.keys(): + if input_name in need_transed_var_names: + var_dtype = ( + inputs[input_name][0].dtype + if isinstance(inputs[input_name], (list, tuple)) + else inputs[input_name].dtype + ) + all_dtypes.append(var_dtype) + + common_dtype = core.get_promote_dtype(op_type, *all_dtypes) + for input_name in inputs.keys(): + if input_name in need_transed_var_names: + var_dtype = ( + inputs[input_name][0].dtype + if isinstance(inputs[input_name], (list, tuple)) + else inputs[input_name].dtype + ) + if var_dtype != common_dtype: + inputs[input_name] = ( + [inputs[input_name][0].astype(common_dtype)] + if isinstance(inputs[input_name], (list, tuple)) + else inputs[input_name].astype(common_dtype) + ) + def append_op(self, *args, **kwargs): """ Appends a new Operator according to the giving arguments. @@ -4394,6 +4434,7 @@ def append_op(self, *args, **kwargs): op_type = kwargs.get("type", None) if in_dygraph_mode(): attrs = kwargs.get("attrs", {}) + inputs = kwargs.get("inputs", {}) warnings.warn( "Op `%s` is executed through `append_op` under the dynamic mode, " "the corresponding API implementation needs to be upgraded to " @@ -4409,6 +4450,8 @@ def append_op(self, *args, **kwargs): attrs=attrs, ) + self._type_promotion_for_inputs(op_type, inputs) + # record ops in tracer rather than blocks # # TODO(minqiyang): add op stop_gradient support in static graph mode too. 
@@ -4416,7 +4459,7 @@ def append_op(self, *args, **kwargs): _dygraph_tracer().trace_op( op_type, - kwargs.get("inputs", {}), + inputs, kwargs.get("outputs", {}), attrs if attrs else {}, kwargs.get("stop_gradient", False), @@ -4440,9 +4483,11 @@ def pass_stop_gradient(ins, outs): if isinstance(var, Variable): var.stop_gradient = True - op_desc = self.desc.append_op() inputs = kwargs.get("inputs", None) outputs = kwargs.get("outputs", None) + + self._type_promotion_for_inputs(op_type, inputs) + op_desc = self.desc.append_op() # NOTE(Aurelius84): In case of @to_static, all Tensor(s) should # be converted into Variable(s) with same name and block location. # This is ONE and ONLY logic of type transformation of dy2static. diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 42c5ff743939d0..a995424dd730c9 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -532,20 +532,11 @@ def __impl__(self, other_var): rhs_dtype = safe_get_dtype(other_var) if lhs_dtype != rhs_dtype: + # NOTE(zoooo0820): Currently, we still keep the old illogical + # logic for compatibility reasons if method_name in SUPPORT_PROMOTION_OPS: - if core.need_type_promotion(lhs_dtype, rhs_dtype): - common_dtype = core.get_promote_dtype( - op_type, lhs_dtype, rhs_dtype - ) - if rhs_dtype != common_dtype: - other_var = astype(other_var, common_dtype) - if lhs_dtype != common_dtype: - self = astype(self, common_dtype) - else: - # NOTE(zoooo0820): Currently, we still keep the old illogical \ - # logic for compatibility reasons + if not core.need_type_promotion(lhs_dtype, rhs_dtype): other_var = astype(other_var, lhs_dtype) - else: other_var = astype(other_var, lhs_dtype) @@ -563,6 +554,13 @@ def __impl__(self, other_var): # NOTE(zhiqiu): the output of compare operator should be bool. 
if method_name in compare_ops: out = create_new_tmp_var(current_block(self), dtype="bool") + elif method_name in SUPPORT_PROMOTION_OPS: + out = create_new_tmp_var( + current_block(self), + dtype=core.get_promote_dtype( + op_type, self.dtype, other_var.dtype + ), + ) else: out = create_new_tmp_var( current_block(self), dtype=safe_get_dtype(self) diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index 89328610e92722..e08d4b0b4763e3 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from op_test import OpTest, convert_float_to_uint16 +from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float import paddle from paddle import base @@ -318,6 +318,61 @@ def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): expect = np.where(cond_data, x_data, y_data) np.testing.assert_array_equal(out[0], expect) + def __test_where_with_type_promotion( + self, x_dtype, y_dtype, expeced_dtype=None + ): + paddle.enable_static() + main_program = paddle.static.Program() + shape = [3, 10] + with paddle.static.program_guard(main_program): + cond = paddle.static.data(name='cond', shape=[3, 10], dtype='bool') + x = paddle.static.data(name='x', shape=shape, dtype=x_dtype) + y = paddle.static.data(name='y', shape=shape, dtype=y_dtype) + cond_data_tmp = np.random.random(size=shape).astype('float32') + cond_data = cond_data_tmp < 0.3 + + if x_dtype != 'bfloat16': + x_data = np.random.random(size=shape).astype(x_dtype) + else: + x_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + if y_dtype != 'bfloat16': + y_data = np.random.random(size=shape).astype(y_dtype) + else: + y_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + result = paddle.where(condition=cond, x=x, y=y) + for use_cuda in [False, True]: + if use_cuda and (not base.core.is_compiled_with_cuda()): + return + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + exe = base.Executor(place) + out = exe.run( + paddle.static.default_main_program(), + feed={'cond': cond_data, 'x': x_data, 'y': y_data}, + fetch_list=[result], + ) + if x_dtype == 'bfloat16' or y_dtype == 'bfloat16': + x_data_convert = ( + convert_uint16_to_float(x_data) + if x_dtype == 'bfloat16' + else x_data + ) + y_data_convert = ( + convert_uint16_to_float(y_data) + if y_dtype == 'bfloat16' + else y_data + ) + expect = np.where(cond_data, x_data_convert, y_data_convert) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype.__str__(), expeced_dtype) + else: + expect = np.where(cond_data, x_data, y_data) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype, expect.dtype) + @test_with_pir_api def test_static_api_broadcast_1(self): cond_shape = [2, 4] @@ -374,6 +429,63 @@ def test_static_api_broadcast_8(self): b_shape = [2, 2, 1] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + def test_static_api_type_promotion_fp16_fp32(self): + x_dtype = 'float16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp16_fp64(self): + x_dtype = 'float16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp32_fp64(self): + x_dtype = 'float32' + y_dtype = 
'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp16(self): + x_dtype = 'bfloat16' + y_dtype = 'float16' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp32(self): + x_dtype = 'bfloat16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp64(self): + x_dtype = 'bfloat16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float64') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float64') + class TestWhereDygraphAPI(unittest.TestCase): def test_api(self): From 88aad1e99d4c4a408e009fe0707a54645edede30 Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Mon, 4 Dec 2023 10:32:27 +0000 Subject: [PATCH 26/27] add warnings --- python/paddle/base/framework.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index f7eac26e282b24..b6ea1e2d0d2a3b 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -4409,6 +4409,11 @@ def _type_promotion_for_inputs(self, op_type, inputs): all_dtypes.append(var_dtype) common_dtype = core.get_promote_dtype(op_type, *all_dtypes) + + warnings.warn( + f"The input dtypes of OP {op_type} are {all_dtypes}, the output will be auto-promoted to {common_dtype}" + ) + for input_name in inputs.keys(): if input_name in need_transed_var_names: var_dtype = ( From 12a2bfb45122f9fb9088c0f10b778e6f285d8afe Mon Sep 17 00:00:00 2001 From: zoooo0820 Date: Tue, 5 Dec 2023 03:01:55 +0000 Subject: [PATCH 27/27] only promote if needed --- python/paddle/base/framework.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index b6ea1e2d0d2a3b..4795d98d0496c3 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -4408,25 +4408,25 @@ def _type_promotion_for_inputs(self, op_type, inputs): ) all_dtypes.append(var_dtype) - common_dtype = core.get_promote_dtype(op_type, *all_dtypes) - - warnings.warn( - f"The input dtypes of OP {op_type} are {all_dtypes}, the output will be auto-promoted to {common_dtype}" - ) + if core.need_type_promotion(*all_dtypes): + common_dtype = core.get_promote_dtype(op_type, *all_dtypes) + warnings.warn( + f"The input dtypes of OP {op_type} are {all_dtypes}, the output will be auto-promoted to {common_dtype}" + ) - for input_name in inputs.keys(): - if input_name in need_transed_var_names: - var_dtype = ( - inputs[input_name][0].dtype - if isinstance(inputs[input_name], (list, tuple)) - else inputs[input_name].dtype - ) - if var_dtype != common_dtype: - inputs[input_name] = ( - [inputs[input_name][0].astype(common_dtype)] + 
+            for input_name in inputs.keys():
+                if input_name in need_transed_var_names:
+                    var_dtype = (
+                        inputs[input_name][0].dtype
+                        if isinstance(inputs[input_name], (list, tuple))
+                        else inputs[input_name].dtype
+                    )
+                    if var_dtype != common_dtype:
+                        inputs[input_name] = (
+                            [inputs[input_name][0].astype(common_dtype)]
+                            if isinstance(inputs[input_name], (list, tuple))
+                            else inputs[input_name].astype(common_dtype)
+                        )
 
     def append_op(self, *args, **kwargs):
         """
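
Taken together, patches 24-27 route mixed-dtype inputs of the supported ops (elementwise_add, elementwise_sub, elementwise_mul, where) through Block._type_promotion_for_inputs, which, when core.need_type_promotion is true, casts them to the common dtype from core.get_promote_dtype and emits a warning. The snippet below is a minimal sketch of the intended user-visible behaviour in static graph mode, assuming a Paddle build with this series applied; it only uses APIs already exercised by the tests above (paddle.static.data, paddle.where, the static Executor).

# Minimal sketch: float16 `x` is auto-promoted to float32 before the
# `where` op is appended, so the fetched result is float32.
# Assumes a Paddle build that includes this patch series.
import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    cond = paddle.static.data(name='cond', shape=[3, 10], dtype='bool')
    x = paddle.static.data(name='x', shape=[3, 10], dtype='float16')
    y = paddle.static.data(name='y', shape=[3, 10], dtype='float32')
    # Block.append_op calls _type_promotion_for_inputs, which casts `x`
    # to the common dtype (float32) and warns about the promotion.
    out = paddle.where(cond, x, y)

exe = paddle.static.Executor(paddle.CPUPlace())
res = exe.run(
    main,
    feed={
        'cond': np.random.random([3, 10]) < 0.5,
        'x': np.random.random([3, 10]).astype('float16'),
        'y': np.random.random([3, 10]).astype('float32'),
    },
    fetch_list=[out],
)
print(res[0].dtype)  # float32, matching np.where(cond, x, y).dtype

Doing the cast inside append_op keeps a single choke point for static graph programs; the TODO in framework.py notes that this op list is meant to be unified with the dygraph promotion table exposed through Pybind later.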