diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 00313512d2c46..8504f8b16c6a8 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -12,90 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once -#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" + #include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/phi/common/data_type_promotion.h" namespace egr { -inline int DataTypeToNum(const phi::DataType& dtype) { - switch (dtype) { - case phi::DataType::UINT8: - return 0; - case phi::DataType::INT8: - return 1; - case phi::DataType::INT16: - return 2; - case phi::DataType::INT32: - return 3; - case phi::DataType::INT64: - return 4; - case phi::DataType::FLOAT16: - return 5; - case phi::DataType::FLOAT32: - return 6; - case phi::DataType::FLOAT64: - return 7; - case phi::DataType::COMPLEX64: - return 8; - case phi::DataType::COMPLEX128: - return 9; - case phi::DataType::BOOL: - return 10; - case phi::DataType::BFLOAT16: - return 11; - default: - PD_THROW("Invalid enum data type for type promote `", dtype, "`."); - } -} - -static inline bool is_support_float(phi::DataType dtype) { - if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::FLOAT32 || - dtype == phi::DataType::FLOAT64 || dtype == phi::DataType::BFLOAT16) { - return true; - } else { - return false; - } -} - -static inline bool is_support_int(phi::DataType dtype) { - if (dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64) { - return true; +inline paddle::Tensor PromoteCast(const std::string& input_name, + const paddle::Tensor& input, + const phi::DataType& dst_dtype, + bool trace_backward = true) { + if (input.dtype() != dst_dtype) { + return Cast(input, dst_dtype, trace_backward); } else { - return false; + return input; } } -inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) { - constexpr auto u1 = phi::DataType::UINT8; - constexpr auto i1 = phi::DataType::INT8; - constexpr auto i2 = phi::DataType::INT16; - constexpr auto i4 = phi::DataType::INT32; - constexpr auto i8 = phi::DataType::INT64; - constexpr auto f2 = phi::DataType::FLOAT16; - constexpr auto f4 = phi::DataType::FLOAT32; - constexpr auto f8 = phi::DataType::FLOAT64; - constexpr auto c4 = phi::DataType::COMPLEX64; - constexpr auto c8 = phi::DataType::COMPLEX128; - constexpr auto b1 = phi::DataType::BOOL; - constexpr auto bf = phi::DataType::BFLOAT16; - - static constexpr phi::DataType _promoteTypesLookup[12][12] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, - /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, - }; - - return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)]; -} - } // namespace egr diff --git 
a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 0674edb09185d..de889deade578 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -204,6 +204,7 @@ limitations under the License. */ #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/include/operants_manager.h" #include "paddle/phi/api/include/tensor_operants.h" +#include "paddle/phi/common/data_type_promotion.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/switch_autotune.h" @@ -883,6 +884,22 @@ PYBIND11_MODULE(libpaddle, m) { &paddle::prim::PrimCommonUtils::SetTargetGradName); m.def("set_num_threads", &platform::SetNumThreads); + m.def("need_type_promotion", + [](framework::proto::VarType::Type type_x, + framework::proto::VarType::Type type_y) { + return phi::NeedTypePromotion(framework::TransToPhiDataType(type_x), + framework::TransToPhiDataType(type_y)); + }); + m.def("get_promote_dtype", + [](const std::string &op_name, + framework::proto::VarType::Type type_x, + framework::proto::VarType::Type type_y) { + return framework::TransToProtoVarType( + phi::GetPromoteDtype(op_name, + framework::TransToPhiDataType(type_x), + framework::TransToPhiDataType(type_y))); + }); + m.def("disable_signal_handler", &DisableSignalHandler); m.def("clear_gradients", diff --git a/paddle/phi/common/data_type_promotion.h b/paddle/phi/common/data_type_promotion.h new file mode 100644 index 0000000000000..fdb3f1e717faf --- /dev/null +++ b/paddle/phi/common/data_type_promotion.h @@ -0,0 +1,115 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
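The two bindings registered above are thin wrappers that translate `proto::VarType::Type` to `phi::DataType` and back around `phi::NeedTypePromotion` and `phi::GetPromoteDtype`. A minimal sanity check of the exposed behaviour, assuming a build that includes these bindings and the usual `core.VarDesc.VarType` enum spellings:

```python
from paddle.base import core

fp16 = core.VarDesc.VarType.FP16
fp32 = core.VarDesc.VarType.FP32

# Only distinct floating-point pairs are promoted for now.
assert core.need_type_promotion(fp16, fp32)
assert not core.need_type_promotion(fp32, fp32)
assert not core.need_type_promotion(
    core.VarDesc.VarType.INT32, core.VarDesc.VarType.INT64
)

# The promoted dtype follows the lookup table; comparisons are
# special-cased to bool by GetPromoteDtype.
assert core.get_promote_dtype("elementwise_add", fp16, fp32) == fp32
assert (
    core.get_promote_dtype("greater_than", fp16, fp32)
    == core.VarDesc.VarType.BOOL
)
```
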
+#pragma once + +#include "paddle/phi/common/data_type.h" +namespace phi { + +inline int DataTypeToNum(const DataType& dtype) { + switch (dtype) { + case DataType::UINT8: + return 0; + case DataType::INT8: + return 1; + case DataType::INT16: + return 2; + case DataType::INT32: + return 3; + case DataType::INT64: + return 4; + case DataType::FLOAT16: + return 5; + case DataType::FLOAT32: + return 6; + case DataType::FLOAT64: + return 7; + case DataType::COMPLEX64: + return 8; + case DataType::COMPLEX128: + return 9; + case DataType::BOOL: + return 10; + case DataType::BFLOAT16: + return 11; + default: + PD_THROW("Invalid enum data type for type promote `", dtype, "`."); + } +} + +inline static DataType promoteTypes(DataType x, DataType y) { + constexpr auto u1 = DataType::UINT8; + constexpr auto i1 = DataType::INT8; + constexpr auto i2 = DataType::INT16; + constexpr auto i4 = DataType::INT32; + constexpr auto i8 = DataType::INT64; + constexpr auto f2 = DataType::FLOAT16; + constexpr auto f4 = DataType::FLOAT32; + constexpr auto f8 = DataType::FLOAT64; + constexpr auto c4 = DataType::COMPLEX64; + constexpr auto c8 = DataType::COMPLEX128; + constexpr auto b1 = DataType::BOOL; + constexpr auto bf = DataType::BFLOAT16; + + const int total_type_num = 12; + + static constexpr DataType + _promoteTypesLookup[total_type_num][total_type_num] = { + /* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf}, + }; + return _promoteTypesLookup[DataTypeToNum(x)][DataTypeToNum(y)]; +} + +static inline bool is_support_float(DataType dtype) { + if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 || + dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) { + return true; + } else { + return false; + } +} + +inline phi::DataType GetPromoteDtype(const std::string& op_name, + const DataType x, + const DataType y) { + // future will deal this by different rule + if (op_name == "greater_than") { + // bool logic + return DataType::BOOL; + } else { + return phi::promoteTypes(x, y); + } +} + +inline bool NeedTypePromotion(const DataType x, const DataType y) { + // Tensor + Tensor only support type promotion for float type + if ((x != y) && is_support_float(x) && is_support_float(y)) { + return true; + } else { + return false; + } +} + +} // namespace phi diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 19342a4ca1f8f..7e5ac9c1d92c4 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -126,7 +126,6 @@ HeterXpuTrainer, ) from .backward import append_backward -from . 
import type_promotion Tensor = LoDTensor enable_imperative = enable_dygraph diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py index c491d90c43a91..12a719765d672 100644 --- a/python/paddle/base/core.py +++ b/python/paddle/base/core.py @@ -342,6 +342,9 @@ def to_list(s): _set_prim_target_grad_name, ) + # type promotion + from .libpaddle import need_type_promotion, get_promote_dtype # noqa: F401 + # isort: on if sys.platform != 'win32': from .libpaddle import ( # noqa: F401 diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index 6723f5e0f40f9..4795d98d0496c 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -56,6 +56,14 @@ CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() _global_flags_ = core.globals() +# TODO(zoooo0820): unify this dict of dygraph and static at Pybind +SUPPORT_PROMOTION_OPS_AND_INPUTNAME = { + "elementwise_add": ['X', 'Y'], + "elementwise_sub": ['X', 'Y'], + "elementwise_mul": ['X', 'Y'], + "where": ['X', 'Y'], +} + def _global_flags(): return _global_flags_ @@ -4383,6 +4391,43 @@ def _is_inited_by(block, var): param.stop_gradient = stop_gradient return param + def _type_promotion_for_inputs(self, op_type, inputs): + need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get( + op_type, None + ) + if need_transed_var_names is None: + return + + all_dtypes = [] + for input_name in inputs.keys(): + if input_name in need_transed_var_names: + var_dtype = ( + inputs[input_name][0].dtype + if isinstance(inputs[input_name], (list, tuple)) + else inputs[input_name].dtype + ) + all_dtypes.append(var_dtype) + + if core.need_type_promotion(*all_dtypes): + common_dtype = core.get_promote_dtype(op_type, *all_dtypes) + warnings.warn( + f"The input dtypes of OP {op_type} are {all_dtypes}, the output will be auto-promoted to {common_dtype}" + ) + + for input_name in inputs.keys(): + if input_name in need_transed_var_names: + var_dtype = ( + inputs[input_name][0].dtype + if isinstance(inputs[input_name], (list, tuple)) + else inputs[input_name].dtype + ) + if var_dtype != common_dtype: + inputs[input_name] = ( + [inputs[input_name][0].astype(common_dtype)] + if isinstance(inputs[input_name], (list, tuple)) + else inputs[input_name].astype(common_dtype) + ) + def append_op(self, *args, **kwargs): """ Appends a new Operator according to the giving arguments. @@ -4394,6 +4439,7 @@ def append_op(self, *args, **kwargs): op_type = kwargs.get("type", None) if in_dygraph_mode(): attrs = kwargs.get("attrs", {}) + inputs = kwargs.get("inputs", {}) warnings.warn( "Op `%s` is executed through `append_op` under the dynamic mode, " "the corresponding API implementation needs to be upgraded to " @@ -4409,6 +4455,8 @@ def append_op(self, *args, **kwargs): attrs=attrs, ) + self._type_promotion_for_inputs(op_type, inputs) + # record ops in tracer rather than blocks # # TODO(minqiyang): add op stop_gradient support in static graph mode too. 
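In static graph mode, `_type_promotion_for_inputs` runs inside `Block.append_op` for the white-listed ops, warns about the mixed input dtypes, and rewrites the lower-precision input with an `astype` cast before the op desc is created. A runnable sketch of the end-to-end effect on `paddle.where`, assuming a wheel built with this change (it mirrors the new `test_where_op` cases):

```python
import numpy as np
import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    cond = paddle.static.data(name="cond", shape=[3], dtype="bool")
    x = paddle.static.data(name="x", shape=[3], dtype="float16")
    y = paddle.static.data(name="y", shape=[3], dtype="float64")
    # "where" is listed in SUPPORT_PROMOTION_OPS_AND_INPUTNAME, so append_op
    # emits the promotion warning and casts X up to float64 before the op
    # is added to the block.
    out = paddle.where(cond, x, y)

exe = paddle.static.Executor(paddle.CPUPlace())
res = exe.run(
    prog,
    feed={
        "cond": np.array([True, False, True]),
        "x": np.ones(3, dtype="float16"),
        "y": np.zeros(3, dtype="float64"),
    },
    fetch_list=[out],
)
print(res[0].dtype)  # float64, matching np.where's promotion
```
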
@@ -4416,7 +4464,7 @@ def append_op(self, *args, **kwargs): _dygraph_tracer().trace_op( op_type, - kwargs.get("inputs", {}), + inputs, kwargs.get("outputs", {}), attrs if attrs else {}, kwargs.get("stop_gradient", False), @@ -4440,9 +4488,11 @@ def pass_stop_gradient(ins, outs): if isinstance(var, Variable): var.stop_gradient = True - op_desc = self.desc.append_op() inputs = kwargs.get("inputs", None) outputs = kwargs.get("outputs", None) + + self._type_promotion_for_inputs(op_type, inputs) + op_desc = self.desc.append_op() # NOTE(Aurelius84): In case of @to_static, all Tensor(s) should # be converted into Variable(s) with same name and block location. # This is ONE and ONLY logic of type transformation of dy2static. diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 1f070882758b9..a995424dd730c 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -32,6 +32,15 @@ compare_ops = ['__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__'] +SUPPORT_PROMOTION_OPS = [ + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", +] + EXPRESSION_MAP = { "__add__": "A + B", "__radd__": "A += B", @@ -519,10 +528,18 @@ def __impl__(self, other_var): current_block(self), value=other_var, dtype=lhs_dtype ) - # 3. unify right var type to left var + # 3. type promotion rhs_dtype = safe_get_dtype(other_var) + if lhs_dtype != rhs_dtype: - other_var = astype(other_var, lhs_dtype) + # NOTE(zoooo0820): Currently, we still keep the old illogical + # logic for compatibility reasons + if method_name in SUPPORT_PROMOTION_OPS: + if not core.need_type_promotion(lhs_dtype, rhs_dtype): + other_var = astype(other_var, lhs_dtype) + else: + other_var = astype(other_var, lhs_dtype) + if reverse: tmp = self self = other_var @@ -537,6 +554,13 @@ def __impl__(self, other_var): # NOTE(zhiqiu): the output of compare operator should be bool. if method_name in compare_ops: out = create_new_tmp_var(current_block(self), dtype="bool") + elif method_name in SUPPORT_PROMOTION_OPS: + out = create_new_tmp_var( + current_block(self), + dtype=core.get_promote_dtype( + op_type, self.dtype, other_var.dtype + ), + ) else: out = create_new_tmp_var( current_block(self), dtype=safe_get_dtype(self) diff --git a/python/paddle/base/type_promotion.py b/python/paddle/base/type_promotion.py deleted file mode 100644 index fd36f554028f8..0000000000000 --- a/python/paddle/base/type_promotion.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
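On the operator-overload path above, `__add__`/`__sub__`/`__mul__` (and their reflected forms) no longer force-cast the right operand to the left operand's dtype when promotion applies, and the temporary output variable is created with the promoted dtype; comparison operators keep their bool output. A small static-graph illustration, assuming a build with this change (the exact dtype reprs printed may differ across Paddle versions):

```python
import paddle

paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    a = paddle.static.data(name="a", shape=[2], dtype="float32")
    b = paddle.static.data(name="b", shape=[2], dtype="float64")

    c = a + b   # elementwise_add: output var is now created as float64
    d = a > b   # greater_than: still bool, b is cast to a's dtype as before

    print(c.dtype)  # FP64 / float64
    print(d.dtype)  # BOOL / bool
```
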
-from paddle.framework import dtype - -u1 = dtype.uint8 -i1 = dtype.int8 -i2 = dtype.int16 -i4 = dtype.int32 -i8 = dtype.int64 -f2 = dtype.float16 -f4 = dtype.float32 -f8 = dtype.float64 -c4 = dtype.complex64 -c8 = dtype.complex128 -b1 = dtype.bool -bf = dtype.bfloat16 - - -Number = { - dtype.uint8: 0, - dtype.int8: 1, - dtype.int16: 2, - dtype.int32: 3, - dtype.int64: 4, - dtype.float16: 5, - dtype.float32: 6, - dtype.float64: 7, - dtype.complex64: 8, - dtype.complex128: 9, - dtype.bool: 10, - dtype.bfloat16: 11, -} - -promoteTypesLookup = [ - [u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf], - [i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf], - [i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf], - [i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf], - [i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf], - [f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4], - [f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4], - [f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8], - [c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4], - [c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8], - [u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf], - [bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf], -] - - -def get_result_dtype(x_dtype, y_dtype): - if x_dtype == y_dtype: - return x_dtype - else: - try: - return promoteTypesLookup[Number[x_dtype]][Number[y_dtype]] - except: - raise TypeError( - f"got unsupport dtype for type promotion: {x_dtype} and {y_dtype}." - ) diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index a4e3f76d7ee8b..c47bfe8e5d1d5 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -51,5 +51,181 @@ def test_operator(self): self.div_operator() +def create_test_case(baseclass, ldtype, rdtype, expected_out_dtype=None): + class TestPromotion(baseclass): + def set_dtype(self): + self.ldtype = ldtype + self.rdtype = rdtype + self.expected_out_dtype = expected_out_dtype + + cls_name = f"{baseclass.__name__}Between{ldtype}And{rdtype}" + TestPromotion.__name__ = cls_name + globals()[cls_name] = TestPromotion + + +class TestOperatorOverloadAddInStatic(unittest.TestCase): + def setUp(self): + paddle.enable_static() + self.set_dtype() + self.exe = paddle.static.Executor() + + def set_dtype(self): + self.ldtype = 'float32' + self.rdtype = 'float64' + self.expected_out_dtype = 'float64' + + def generate_test_value(self): + self.l_value = (paddle.randn((4, 3, 2)) * 10).astype(self.ldtype) + self.r_value = (paddle.randn((4, 3, 2)) * 10).astype(self.rdtype) + + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value + self.r_value + out_reverse = self.r_value + self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + def test_dtype_is_expected(self): + res = self.run_api() + self.assertEqual(res[0].dtype.__str__(), self.expected_out_dtype) + self.assertEqual(res[1].dtype.__str__(), self.expected_out_dtype) + + +create_test_case( + TestOperatorOverloadAddInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadAddInStatic, 'float16', 'float64', 'float64' +) + +create_test_case( + TestOperatorOverloadAddInStatic, 'float32', 'float64', 'float64' +) + + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + 
TestOperatorOverloadAddInStatic, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestOperatorOverloadAddInStatic, 'bfloat16', 'float64', 'float64' + ) + + +class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic): + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value - self.r_value + out_reverse = self.r_value - self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case( + TestOperatorOverloadSubInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadSubInStatic, 'float16', 'float64', 'float64' +) + +create_test_case( + TestOperatorOverloadSubInStatic, 'float32', 'float64', 'float64' +) + + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestOperatorOverloadSubInStatic, 'bfloat16', 'float64', 'float64' + ) + + +class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic): + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value * self.r_value + out_reverse = self.r_value * self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32' +) +create_test_case( + TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64' +) + +create_test_case( + TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64' +) + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64' + ) + + +class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic): + def set_dtype(self): + self.ldtype = 'float32' + self.rdtype = 'float64' + self.expected_out_dtype = 'bool' + + def run_api(self): + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + self.generate_test_value() + + out = self.l_value > self.r_value + out_reverse = self.r_value > self.l_value + + res = self.exe.run(prog, fetch_list=[out, out_reverse]) + return res + + +create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool') +create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool') + +create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool') + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool' + ) + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool' + ) + create_test_case( + TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool' + ) + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index 89328610e9272..e08d4b0b4763e 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from op_test import OpTest, convert_float_to_uint16 +from 
op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float import paddle from paddle import base @@ -318,6 +318,61 @@ def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape): expect = np.where(cond_data, x_data, y_data) np.testing.assert_array_equal(out[0], expect) + def __test_where_with_type_promotion( + self, x_dtype, y_dtype, expeced_dtype=None + ): + paddle.enable_static() + main_program = paddle.static.Program() + shape = [3, 10] + with paddle.static.program_guard(main_program): + cond = paddle.static.data(name='cond', shape=[3, 10], dtype='bool') + x = paddle.static.data(name='x', shape=shape, dtype=x_dtype) + y = paddle.static.data(name='y', shape=shape, dtype=y_dtype) + cond_data_tmp = np.random.random(size=shape).astype('float32') + cond_data = cond_data_tmp < 0.3 + + if x_dtype != 'bfloat16': + x_data = np.random.random(size=shape).astype(x_dtype) + else: + x_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + if y_dtype != 'bfloat16': + y_data = np.random.random(size=shape).astype(y_dtype) + else: + y_data = convert_float_to_uint16( + np.random.random(size=shape).astype('float32') + ) + result = paddle.where(condition=cond, x=x, y=y) + for use_cuda in [False, True]: + if use_cuda and (not base.core.is_compiled_with_cuda()): + return + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + exe = base.Executor(place) + out = exe.run( + paddle.static.default_main_program(), + feed={'cond': cond_data, 'x': x_data, 'y': y_data}, + fetch_list=[result], + ) + if x_dtype == 'bfloat16' or y_dtype == 'bfloat16': + x_data_convert = ( + convert_uint16_to_float(x_data) + if x_dtype == 'bfloat16' + else x_data + ) + y_data_convert = ( + convert_uint16_to_float(y_data) + if y_dtype == 'bfloat16' + else y_data + ) + expect = np.where(cond_data, x_data_convert, y_data_convert) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype.__str__(), expeced_dtype) + else: + expect = np.where(cond_data, x_data, y_data) + np.testing.assert_array_equal(out[0], expect) + self.assertEqual(out[0].dtype, expect.dtype) + @test_with_pir_api def test_static_api_broadcast_1(self): cond_shape = [2, 4] @@ -374,6 +429,63 @@ def test_static_api_broadcast_8(self): b_shape = [2, 2, 1] self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape) + def test_static_api_type_promotion_fp16_fp32(self): + x_dtype = 'float16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp16_fp64(self): + x_dtype = 'float16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + def test_static_api_type_promotion_fp32_fp64(self): + x_dtype = 'float32' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype) + self.__test_where_with_type_promotion(y_dtype, x_dtype) + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp16(self): + x_dtype = 'bfloat16' + y_dtype = 'float16' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current 
device", + ) + def test_static_api_type_promotion_bf16_fp32(self): + x_dtype = 'bfloat16' + y_dtype = 'float32' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float32') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float32') + + @unittest.skipIf( + not ( + paddle.is_compiled_with_cuda() + and paddle.base.core.supports_bfloat16() + ), + "bf16 is not supported in current device", + ) + def test_static_api_type_promotion_bf16_fp64(self): + x_dtype = 'bfloat16' + y_dtype = 'float64' + self.__test_where_with_type_promotion(x_dtype, y_dtype, 'float64') + self.__test_where_with_type_promotion(y_dtype, x_dtype, 'float64') + class TestWhereDygraphAPI(unittest.TestCase): def test_api(self):