diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc index aa18f8cd4acb8..cfea756cf02d5 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc @@ -70,7 +70,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } // Type promotion Logic - if (phi::NeedTypePromotion(x.dtype(), y.dtype())) { + if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) { VLOG(5) << "got different data type, run type promotion automatically."; LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion " @@ -247,6 +247,22 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT VLOG(5) << " No AMP for multiply__ad_func because it is a inplace or cast api. "; + + // Type promotion Logic + if (phi::NeedTypePromotion("multiply_", x.dtype(), y.dtype())) { + VLOG(5) << "got different data type, run type promotion automatically."; + LOG_FIRST_N(WARNING, 1) + << "got different data type, run type promotion " + "automatically, this may cause data type been changed."; + auto op_name = phi::TransToFluidOpName("multiply_"); + auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype()); + + x = egr::PromoteCastInplace("x", x, promotion_type); + auto new_y = egr::PromoteCast("y", y, promotion_type); + + return multiply__ad_func(x, new_y); + } + // Layout autotune if (egr::Controller::Instance().UseLayoutAutoTune()) { @@ -424,7 +440,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } // Type promotion Logic - if (phi::NeedTypePromotion(x.dtype(), y.dtype())) { + if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) { VLOG(5) << "got different data type, run type promotion automatically."; LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion " diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 50a9ffc29e39c..d7379ffb4e444 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -108,6 +108,24 @@ "atan2": ["x", "y"], } +type_promote_inplace_white_list = { + "add_": ["x", "y"], + "subtract_": ["x", "y"], + "divide_": ["x", "y"], + "floor_divide_": ["x", "y"], + "where_": ["x", "y"], + "equal_": ["x", "y"], + "not_equal_": ["x", "y"], + "less_than_": ["x", "y"], + "less_equal_": ["x", "y"], + "greater_than_": ["x", "y"], + "greater_equal_": ["x", "y"], + "logical_and_": ["x", "y"], + "logical_or_": ["x", "y"], + "logical_xor_": ["x", "y"], + "remainder_": ["x", "y"], +} + # dict of special api that forward api's output will affect backward api's output # backward api's output usually affected by backward api's input @@ -558,13 +576,13 @@ class {} : public egr::GradNodeBase {{ }} """ -TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({x}.dtype(), {y}.dtype())) {{ +TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({op_func_name}, {x}.dtype(), {y}.dtype())) {{ VLOG(5) << "got different data type, run type promotion automatically."; LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion automatically, this may cause data type been changed."; {op_name} auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype()); - auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type); + {x_cast} auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type); {return_value} @@ -1532,6 +1550,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): type_promote_inputs_call_list[pos] = f"new_{name}" else: type_promote_inputs_call_list[pos] = f"{name}" + elif forward_api_name in type_promote_inplace_white_list: + if name in type_promote_inplace_white_list[forward_api_name]: + if ( + is_inplaced + and forward_inplace_map + and name in forward_inplace_map + ): + type_promote_inputs_call_list[pos] = f"{name}" + else: + type_promote_inputs_call_list[pos] = f"new_{name}" + else: + type_promote_inputs_call_list[pos] = f"{name}" if IsPlainTensorType(ttype): if is_optional: if ( @@ -1868,6 +1898,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): # Forward type promotion logic if forward_api_name in type_promote_white_list: # only support two inputs + op_func_name = f"\"{forward_api_name}\"" x = type_promote_white_list[forward_api_name][0] y = type_promote_white_list[forward_api_name][1] type_promote_inputs_call_args_str = ", ".join( @@ -1875,9 +1906,35 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): ) type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});" + x_cast = f"auto new_{x} = egr::PromoteCast(\"{x}\", {x}, promotion_type);" + + type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format( + op_func_name=op_func_name, + x=x, + y=y, + x_cast=x_cast, + op_name=kernel_trans2_op_name_str, + return_value=type_promote_call_list, + ) + elif forward_api_name in type_promote_inplace_white_list: + # only support two inputs + op_func_name = f"\"{forward_api_name}\"" + x = type_promote_inplace_white_list[forward_api_name][0] + y = type_promote_inplace_white_list[forward_api_name][1] + type_promote_inputs_call_args_str = ", ".join( + type_promote_inputs_call_list + ) + type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});" + + x_cast = ( + f"{x} = egr::PromoteCastInplace(\"{x}\", {x}, promotion_type);" + ) + type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format( + op_func_name=op_func_name, x=x, y=y, + x_cast=x_cast, op_name=kernel_trans2_op_name_str, return_value=type_promote_call_list, ) diff --git a/paddle/fluid/eager/type_promotion_utils.h b/paddle/fluid/eager/type_promotion_utils.h index 3ef732bac78bf..7ab9965cd15c4 100644 --- a/paddle/fluid/eager/type_promotion_utils.h +++ b/paddle/fluid/eager/type_promotion_utils.h @@ -30,4 +30,15 @@ inline paddle::Tensor PromoteCast(const std::string& input_name, } } +inline paddle::Tensor PromoteCastInplace(const std::string& input_name, + paddle::Tensor& input, // NOLINT + const phi::DataType& dst_dtype, + bool trace_backward = true) { + if (input.dtype() != dst_dtype) { + return paddle::experimental::cast_(input, dst_dtype); + } else { + return input; + } +} + } // namespace egr diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 02d87fe02e00d..271aebaae7e49 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -924,9 +924,11 @@ PYBIND11_MODULE(libpaddle, m) { m.def("set_num_threads", &platform::SetNumThreads); m.def("need_type_promotion", - [](framework::proto::VarType::Type type_x, + [](const std::string &op_name, + framework::proto::VarType::Type type_x, framework::proto::VarType::Type type_y) { - return phi::NeedTypePromotion(framework::TransToPhiDataType(type_x), + return phi::NeedTypePromotion(op_name, + framework::TransToPhiDataType(type_x), framework::TransToPhiDataType(type_y)); }); m.def("get_promote_dtype", diff --git a/paddle/phi/common/type_promotion.h b/paddle/phi/common/type_promotion.h index 7bc50e90e7a4d..e8d8af1221c0b 100644 --- a/paddle/phi/common/type_promotion.h +++ b/paddle/phi/common/type_promotion.h @@ -133,10 +133,20 @@ inline phi::DataType GetPromoteDtype(const std::string& op_name, return phi::promoteTypes(x, y); } -inline bool NeedTypePromotion(const DataType x, const DataType y) { +inline bool NeedTypePromotion(const std::string& op_name, + const DataType x, + const DataType y) { // Tensor + Tensor type promotion only support calculations between // floating-point numbers and between complex and real numbers. if (x != y) { +// TODO(Xi Zhao): we got special case for add now, should remove it in furture. +#ifdef PADDLE_WITH_CUDA + if (op_name == "add" && x == DataType::FLOAT32 && + (y == phi::DataType::BFLOAT16 || y == phi::DataType::FLOAT16)) { + return false; + } +#endif + if ((is_support_float(x) && is_support_float(y)) || (is_support_complex(x) || is_support_complex(y))) { return true; diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index 823e4c760f3ea..aa4738d194ff9 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -8304,7 +8304,9 @@ def process_type_promotion(program): all_input_name_need_cast.append(input_arg_name) # only support promote between float - if len(all_dtypes) == 2 and core.need_type_promotion(*all_dtypes): + if len(all_dtypes) == 2 and core.need_type_promotion( + op.type, *all_dtypes + ): common_dtype = core.get_promote_dtype(op.type, *all_dtypes) for input_name_need_cast in all_input_name_need_cast: var_name = op.block._var_recursive(input_name_need_cast) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index 40b5659b067d3..241f395e8a518 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -656,7 +656,9 @@ def __impl__(self, other_var): self = astype(self, rhs_dtype) else: other_var = astype(other_var, lhs_dtype) - elif core.need_type_promotion(lhs_dtype, rhs_dtype): + elif core.need_type_promotion( + op_type, lhs_dtype, rhs_dtype + ): # only report warning here, real promotion deal in Executor warnings.warn( f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted" diff --git a/test/legacy_test/test_tensor_type_promotion.py b/test/legacy_test/test_tensor_type_promotion.py index 155f78c78bdf8..8e4e425babb1e 100644 --- a/test/legacy_test/test_tensor_type_promotion.py +++ b/test/legacy_test/test_tensor_type_promotion.py @@ -229,6 +229,83 @@ def run_api(self): create_test_case(TestAPIAddInDygraph, 'complex128', 'float64', 'complex128') +class TestAPIAddInplaceInDygraph(TestOperatorOverloadAddInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.add_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.add_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPIAddInplaceInDygraph, 'float16', 'float32', 'float32') +create_test_case(TestAPIAddInplaceInDygraph, 'float16', 'float64', 'float64') + +create_test_case(TestAPIAddInplaceInDygraph, 'float32', 'float64', 'float64') + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestAPIAddInplaceInDygraph, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestAPIAddInplaceInDygraph, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestAPIAddInplaceInDygraph, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestAPIAddInplaceInDygraph, 'bfloat16', 'complex64', 'complex64' + ) + create_test_case( + TestAPIAddInplaceInDygraph, 'bfloat16', 'complex128', 'complex128' + ) + +create_test_case(TestAPIAddInplaceInDygraph, 'complex64', 'bool', 'complex64') +create_test_case(TestAPIAddInplaceInDygraph, 'complex64', 'int8', 'complex64') +create_test_case(TestAPIAddInplaceInDygraph, 'complex64', 'uint8', 'complex64') +create_test_case(TestAPIAddInplaceInDygraph, 'complex64', 'int16', 'complex64') +create_test_case(TestAPIAddInplaceInDygraph, 'complex64', 'int32', 'complex64') +create_test_case(TestAPIAddInplaceInDygraph, 'complex64', 'int64', 'complex64') +create_test_case( + TestAPIAddInplaceInDygraph, 'complex64', 'float16', 'complex64' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex64', 'float32', 'complex64' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex64', 'float64', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex64', 'complex128', 'complex128' +) + +create_test_case(TestAPIAddInplaceInDygraph, 'complex128', 'bool', 'complex128') +create_test_case(TestAPIAddInplaceInDygraph, 'complex128', 'int8', 'complex128') +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'uint8', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'int16', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'int32', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'int64', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'float16', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'float32', 'complex128' +) +create_test_case( + TestAPIAddInplaceInDygraph, 'complex128', 'float64', 'complex128' +) + + class TestOperatorOverloadSubInDygraph(TestOperatorOverloadAddInDygraph): def run_api(self): self.generate_test_value() @@ -373,6 +450,83 @@ def run_api(self): create_test_case(TestAPISubInDygraph, 'complex128', 'float64', 'complex128') +class TestAPISubInplaceInDygraph(TestOperatorOverloadAddInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.subtract_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.subtract_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPISubInplaceInDygraph, 'float16', 'float32', 'float32') +create_test_case(TestAPISubInplaceInDygraph, 'float16', 'float64', 'float64') + +create_test_case(TestAPISubInplaceInDygraph, 'float32', 'float64', 'float64') + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestAPISubInplaceInDygraph, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestAPISubInplaceInDygraph, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestAPISubInplaceInDygraph, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestAPISubInplaceInDygraph, 'bfloat16', 'complex64', 'complex64' + ) + create_test_case( + TestAPISubInplaceInDygraph, 'bfloat16', 'complex128', 'complex128' + ) + +create_test_case(TestAPISubInplaceInDygraph, 'complex64', 'bool', 'complex64') +create_test_case(TestAPISubInplaceInDygraph, 'complex64', 'int8', 'complex64') +create_test_case(TestAPISubInplaceInDygraph, 'complex64', 'uint8', 'complex64') +create_test_case(TestAPISubInplaceInDygraph, 'complex64', 'int16', 'complex64') +create_test_case(TestAPISubInplaceInDygraph, 'complex64', 'int32', 'complex64') +create_test_case(TestAPISubInplaceInDygraph, 'complex64', 'int64', 'complex64') +create_test_case( + TestAPISubInplaceInDygraph, 'complex64', 'float16', 'complex64' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex64', 'float32', 'complex64' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex64', 'float64', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex64', 'complex128', 'complex128' +) + +create_test_case(TestAPISubInplaceInDygraph, 'complex128', 'bool', 'complex128') +create_test_case(TestAPISubInplaceInDygraph, 'complex128', 'int8', 'complex128') +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'uint8', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'int16', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'int32', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'int64', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'float16', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'float32', 'complex128' +) +create_test_case( + TestAPISubInplaceInDygraph, 'complex128', 'float64', 'complex128' +) + + class TestOperatorOverloadMulInDygraph(TestOperatorOverloadAddInDygraph): def run_api(self): self.generate_test_value() @@ -517,6 +671,83 @@ def run_api(self): create_test_case(TestAPIMulInDygraph, 'complex128', 'float64', 'complex128') +class TestAPIMulInplaceInDygraph(TestOperatorOverloadAddInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.multiply_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.multiply_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPIMulInplaceInDygraph, 'float16', 'float32', 'float32') +create_test_case(TestAPIMulInplaceInDygraph, 'float16', 'float64', 'float64') + +create_test_case(TestAPIMulInplaceInDygraph, 'float32', 'float64', 'float64') + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestAPIMulInplaceInDygraph, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestAPIMulInplaceInDygraph, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestAPIMulInplaceInDygraph, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestAPIMulInplaceInDygraph, 'bfloat16', 'complex64', 'complex64' + ) + create_test_case( + TestAPIMulInplaceInDygraph, 'bfloat16', 'complex128', 'complex128' + ) + +create_test_case(TestAPIMulInplaceInDygraph, 'complex64', 'bool', 'complex64') +create_test_case(TestAPIMulInplaceInDygraph, 'complex64', 'int8', 'complex64') +create_test_case(TestAPIMulInplaceInDygraph, 'complex64', 'uint8', 'complex64') +create_test_case(TestAPIMulInplaceInDygraph, 'complex64', 'int16', 'complex64') +create_test_case(TestAPIMulInplaceInDygraph, 'complex64', 'int32', 'complex64') +create_test_case(TestAPIMulInplaceInDygraph, 'complex64', 'int64', 'complex64') +create_test_case( + TestAPIMulInplaceInDygraph, 'complex64', 'float16', 'complex64' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex64', 'float32', 'complex64' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex64', 'float64', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex64', 'complex128', 'complex128' +) + +create_test_case(TestAPIMulInplaceInDygraph, 'complex128', 'bool', 'complex128') +create_test_case(TestAPIMulInplaceInDygraph, 'complex128', 'int8', 'complex128') +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'uint8', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'int16', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'int32', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'int64', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'float16', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'float32', 'complex128' +) +create_test_case( + TestAPIMulInplaceInDygraph, 'complex128', 'float64', 'complex128' +) + + class TestOperatorOverloadDivInDygraph(TestOperatorOverloadAddInDygraph): def run_api(self): self.generate_test_value() @@ -661,6 +892,83 @@ def run_api(self): create_test_case(TestAPIDivInDygraph, 'complex128', 'float64', 'complex128') +class TestAPIDivInplaceInDygraph(TestOperatorOverloadAddInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.divide_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.divide_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPIDivInplaceInDygraph, 'float16', 'float32', 'float32') +create_test_case(TestAPIDivInplaceInDygraph, 'float16', 'float64', 'float64') + +create_test_case(TestAPIDivInplaceInDygraph, 'float32', 'float64', 'float64') + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case( + TestAPIDivInplaceInDygraph, 'bfloat16', 'float16', 'float32' + ) + create_test_case( + TestAPIDivInplaceInDygraph, 'bfloat16', 'float32', 'float32' + ) + create_test_case( + TestAPIDivInplaceInDygraph, 'bfloat16', 'float64', 'float64' + ) + create_test_case( + TestAPIDivInplaceInDygraph, 'bfloat16', 'complex64', 'complex64' + ) + create_test_case( + TestAPIDivInplaceInDygraph, 'bfloat16', 'complex128', 'complex128' + ) + +create_test_case(TestAPIDivInplaceInDygraph, 'complex64', 'bool', 'complex64') +create_test_case(TestAPIDivInplaceInDygraph, 'complex64', 'int8', 'complex64') +create_test_case(TestAPIDivInplaceInDygraph, 'complex64', 'uint8', 'complex64') +create_test_case(TestAPIDivInplaceInDygraph, 'complex64', 'int16', 'complex64') +create_test_case(TestAPIDivInplaceInDygraph, 'complex64', 'int32', 'complex64') +create_test_case(TestAPIDivInplaceInDygraph, 'complex64', 'int64', 'complex64') +create_test_case( + TestAPIDivInplaceInDygraph, 'complex64', 'float16', 'complex64' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex64', 'float32', 'complex64' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex64', 'float64', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex64', 'complex128', 'complex128' +) + +create_test_case(TestAPIDivInplaceInDygraph, 'complex128', 'bool', 'complex128') +create_test_case(TestAPIDivInplaceInDygraph, 'complex128', 'int8', 'complex128') +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'uint8', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'int16', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'int32', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'int64', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'float16', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'float32', 'complex128' +) +create_test_case( + TestAPIDivInplaceInDygraph, 'complex128', 'float64', 'complex128' +) + + class TestOperatorOverloadPowInDygraph(TestOperatorOverloadAddInDygraph): def run_api(self): self.generate_test_value() @@ -742,26 +1050,48 @@ def run_api(self): return out, out_reverse +create_test_case(TestAPIFloorDivInDygraph, 'float16', 'float32', 'float32') +create_test_case(TestAPIFloorDivInDygraph, 'float16', 'float64', 'float64') + +create_test_case(TestAPIFloorDivInDygraph, 'float32', 'float64', 'float64') + +if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): + create_test_case(TestAPIFloorDivInDygraph, 'bfloat16', 'float16', 'float32') + create_test_case(TestAPIFloorDivInDygraph, 'bfloat16', 'float32', 'float32') + create_test_case(TestAPIFloorDivInDygraph, 'bfloat16', 'float64', 'float64') + + +class TestAPIFloorDivInplaceInDygraph(TestOperatorOverloadAddInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.floor_divide_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.floor_divide_(self.l_value) + + return out, out_reverse + + create_test_case( - TestOperatorOverloadFloorDivInDygraph, 'float16', 'float32', 'float32' + TestAPIFloorDivInplaceInDygraph, 'float16', 'float32', 'float32' ) create_test_case( - TestOperatorOverloadFloorDivInDygraph, 'float16', 'float64', 'float64' + TestAPIFloorDivInplaceInDygraph, 'float16', 'float64', 'float64' ) create_test_case( - TestOperatorOverloadFloorDivInDygraph, 'float32', 'float64', 'float64' + TestAPIFloorDivInplaceInDygraph, 'float32', 'float64', 'float64' ) if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16(): create_test_case( - TestOperatorOverloadFloorDivInDygraph, 'bfloat16', 'float16', 'float32' + TestAPIFloorDivInplaceInDygraph, 'bfloat16', 'float16', 'float32' ) create_test_case( - TestOperatorOverloadFloorDivInDygraph, 'bfloat16', 'float32', 'float32' + TestAPIFloorDivInplaceInDygraph, 'bfloat16', 'float32', 'float32' ) create_test_case( - TestOperatorOverloadFloorDivInDygraph, 'bfloat16', 'float64', 'float64' + TestAPIFloorDivInplaceInDygraph, 'bfloat16', 'float64', 'float64' ) @@ -803,6 +1133,23 @@ def run_api(self): create_test_case(TestAPIModInDygraph, 'float32', 'float64', 'float64') +class TestAPIModInplaceInDygraph(TestOperatorOverloadAddInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.mod_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.mod_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPIModInplaceInDygraph, 'float16', 'float32', 'float32') +create_test_case(TestAPIModInplaceInDygraph, 'float16', 'float64', 'float64') + +create_test_case(TestAPIModInplaceInDygraph, 'float32', 'float64', 'float64') + + class TestOperatorOverloadEqualInDygraph(unittest.TestCase): def setUp(self): paddle.disable_static() @@ -863,6 +1210,23 @@ def run_api(self): create_test_case(TestAPIEqualInDygraph, 'float32', 'float64', 'bool') +class TestAPIEqualInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.equal_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.equal_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPIEqualInplaceInDygraph, 'float16', 'float32', 'bool') +create_test_case(TestAPIEqualInplaceInDygraph, 'float16', 'float64', 'bool') + +create_test_case(TestAPIEqualInplaceInDygraph, 'float32', 'float64', 'bool') + + class TestOperatorOverloadNotEqualInDygraph(TestOperatorOverloadEqualInDygraph): def run_api(self): self.generate_test_value() @@ -901,6 +1265,23 @@ def run_api(self): create_test_case(TestAPINotEqualInDygraph, 'float32', 'float64', 'bool') +class TestAPINotEqualInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.not_equal_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.not_equal_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPINotEqualInplaceInDygraph, 'float16', 'float32', 'bool') +create_test_case(TestAPINotEqualInplaceInDygraph, 'float16', 'float64', 'bool') + +create_test_case(TestAPINotEqualInplaceInDygraph, 'float32', 'float64', 'bool') + + class TestOperatorOverloadLessThanInDygraph(TestOperatorOverloadEqualInDygraph): def run_api(self): self.generate_test_value() @@ -939,6 +1320,23 @@ def run_api(self): create_test_case(TestAPILessThanInDygraph, 'float32', 'float64', 'bool') +class TestAPILessThanInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.less_than_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.less_than_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPILessThanInplaceInDygraph, 'float16', 'float32', 'bool') +create_test_case(TestAPILessThanInplaceInDygraph, 'float16', 'float64', 'bool') + +create_test_case(TestAPILessThanInplaceInDygraph, 'float32', 'float64', 'bool') + + class TestOperatorOverloadLessEqualInDygraph( TestOperatorOverloadEqualInDygraph ): @@ -979,6 +1377,23 @@ def run_api(self): create_test_case(TestAPILessEqualInDygraph, 'float32', 'float64', 'bool') +class TestAPILessEqualInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.less_equal_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.less_equal_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPILessEqualInplaceInDygraph, 'float16', 'float32', 'bool') +create_test_case(TestAPILessEqualInplaceInDygraph, 'float16', 'float64', 'bool') + +create_test_case(TestAPILessEqualInplaceInDygraph, 'float32', 'float64', 'bool') + + class TestOperatorOverloadGreaterThanInDygraph( TestOperatorOverloadEqualInDygraph ): @@ -1019,6 +1434,29 @@ def run_api(self): create_test_case(TestAPIGreaterThanInDygraph, 'float32', 'float64', 'bool') +class TestAPIGreaterThanInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.greater_than_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.greater_than_(self.l_value) + + return out, out_reverse + + +create_test_case( + TestAPIGreaterThanInplaceInDygraph, 'float16', 'float32', 'bool' +) +create_test_case( + TestAPIGreaterThanInplaceInDygraph, 'float16', 'float64', 'bool' +) + +create_test_case( + TestAPIGreaterThanInplaceInDygraph, 'float32', 'float64', 'bool' +) + + class TestOperatorOverloadGreaterEqualInDygraph( TestOperatorOverloadEqualInDygraph ): @@ -1059,6 +1497,29 @@ def run_api(self): create_test_case(TestAPIGreaterEqualInDygraph, 'float32', 'float64', 'bool') +class TestAPIGreaterEqualInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.greater_equal_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.greater_equal_(self.l_value) + + return out, out_reverse + + +create_test_case( + TestAPIGreaterEqualInplaceInDygraph, 'float16', 'float32', 'bool' +) +create_test_case( + TestAPIGreaterEqualInplaceInDygraph, 'float16', 'float64', 'bool' +) + +create_test_case( + TestAPIGreaterEqualInplaceInDygraph, 'float32', 'float64', 'bool' +) + + class TestAPILogicalAndInDygraph(TestOperatorOverloadEqualInDygraph): def run_api(self): self.generate_test_value() @@ -1094,6 +1555,78 @@ def run_api(self): create_test_case(TestAPILogicalAndInDygraph, 'complex128', 'float64', 'bool') +class TestAPILogicalAndInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.logical_and_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.logical_and_(self.l_value) + + return out, out_reverse + + +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'float16', 'float32', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'float16', 'float64', 'bool' +) + +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'float32', 'float64', 'bool' +) + +create_test_case(TestAPILogicalAndInplaceInDygraph, 'complex64', 'bool', 'bool') +create_test_case(TestAPILogicalAndInplaceInDygraph, 'complex64', 'int8', 'bool') +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'int16', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'int32', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'int64', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'float16', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'float32', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'float64', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex64', 'complex128', 'bool' +) + +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'bool', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'int8', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'int16', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'int32', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'int64', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'float16', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'float32', 'bool' +) +create_test_case( + TestAPILogicalAndInplaceInDygraph, 'complex128', 'float64', 'bool' +) + + class TestAPILogicalOrInDygraph(TestOperatorOverloadEqualInDygraph): def run_api(self): self.generate_test_value() @@ -1129,6 +1662,62 @@ def run_api(self): create_test_case(TestAPILogicalOrInDygraph, 'complex128', 'float64', 'bool') +class TestAPILogicalOrInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.logical_or_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.logical_or_(self.l_value) + + return out, out_reverse + + +create_test_case(TestAPILogicalOrInplaceInDygraph, 'float16', 'float32', 'bool') +create_test_case(TestAPILogicalOrInplaceInDygraph, 'float16', 'float64', 'bool') + +create_test_case(TestAPILogicalOrInplaceInDygraph, 'float32', 'float64', 'bool') + +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex64', 'bool', 'bool') +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex64', 'int8', 'bool') +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex64', 'int16', 'bool') +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex64', 'int32', 'bool') +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex64', 'int64', 'bool') +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex64', 'float16', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex64', 'float32', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex64', 'float64', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex64', 'complex128', 'bool' +) + +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex128', 'bool', 'bool') +create_test_case(TestAPILogicalOrInplaceInDygraph, 'complex128', 'int8', 'bool') +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex128', 'int16', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex128', 'int32', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex128', 'int64', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex128', 'float16', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex128', 'float32', 'bool' +) +create_test_case( + TestAPILogicalOrInplaceInDygraph, 'complex128', 'float64', 'bool' +) + + class TestAPILogicalXorInDygraph(TestOperatorOverloadEqualInDygraph): def run_api(self): self.generate_test_value() @@ -1164,6 +1753,78 @@ def run_api(self): create_test_case(TestAPILogicalXorInDygraph, 'complex128', 'float64', 'bool') +class TestAPILogicalXorInplaceInDygraph(TestOperatorOverloadEqualInDygraph): + def run_api(self): + self.generate_test_value() + out = self.l_value.logical_xor_(self.r_value) + + self.generate_test_value() + out_reverse = self.r_value.logical_xor_(self.l_value) + + return out, out_reverse + + +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'float16', 'float32', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'float16', 'float64', 'bool' +) + +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'float32', 'float64', 'bool' +) + +create_test_case(TestAPILogicalXorInplaceInDygraph, 'complex64', 'bool', 'bool') +create_test_case(TestAPILogicalXorInplaceInDygraph, 'complex64', 'int8', 'bool') +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'int16', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'int32', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'int64', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'float16', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'float32', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'float64', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex64', 'complex128', 'bool' +) + +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'bool', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'int8', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'int16', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'int32', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'int64', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'float16', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'float32', 'bool' +) +create_test_case( + TestAPILogicalXorInplaceInDygraph, 'complex128', 'float64', 'bool' +) + + class TestAPIFmaxInDygraph(TestOperatorOverloadAddInDygraph): def run_api(self): self.generate_test_value()