add inplace API support, add a special case that can skip type promotion (add with x=float32, y=float16/bfloat16).
zxcd committed Apr 24, 2024
1 parent dc624f8 commit de8ac06
Showing 8 changed files with 776 additions and 15 deletions.
@@ -70,7 +70,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
}

// Type promotion Logic
if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1)
<< "got different data type, run type promotion "
@@ -247,6 +247,22 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT

VLOG(5)
<< " No AMP for multiply__ad_func because it is a inplace or cast api. ";

// Type promotion Logic
if (phi::NeedTypePromotion("multiply_", x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1)
<< "got different data type, run type promotion "
"automatically, this may cause data type been changed.";
auto op_name = phi::TransToFluidOpName("multiply_");
auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype());

x = egr::PromoteCastInplace("x", x, promotion_type);
auto new_y = egr::PromoteCast("y", y, promotion_type);

return multiply__ad_func(x, new_y);
}

// Layout autotune

if (egr::Controller::Instance().UseLayoutAutoTune()) {
@@ -424,7 +440,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
}

// Type promotion Logic
if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1)
<< "got different data type, run type promotion "
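At the Python level, the intended effect of the new inplace branch added to multiply__ad_func above is roughly the following (a hedged sketch, not taken from the diff: shapes and values are made up, and it assumes a build that contains this commit). With mismatched floating-point dtypes, the inplace operand x is cast in place to the promoted dtype via egr::PromoteCastInplace, while y goes through the out-of-place egr::PromoteCast:

import paddle

x = paddle.ones([4], dtype='float16')
y = paddle.ones([4], dtype='float32')

# promoteTypes(float16, float32) -> float32, so x itself is promoted in place;
# y already has the promoted dtype, so PromoteCast returns it unchanged.
x.multiply_(y)
print(x.dtype)  # expected: paddle.float32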
61 changes: 59 additions & 2 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -108,6 +108,24 @@
"atan2": ["x", "y"],
}

type_promote_inplace_white_list = {
"add_": ["x", "y"],
"subtract_": ["x", "y"],
"divide_": ["x", "y"],
"floor_divide_": ["x", "y"],
"where_": ["x", "y"],
"equal_": ["x", "y"],
"not_equal_": ["x", "y"],
"less_than_": ["x", "y"],
"less_equal_": ["x", "y"],
"greater_than_": ["x", "y"],
"greater_equal_": ["x", "y"],
"logical_and_": ["x", "y"],
"logical_or_": ["x", "y"],
"logical_xor_": ["x", "y"],
"remainder_": ["x", "y"],
}

# dict of special api that forward api's output will affect backward api's output
# backward api's output usually affected by backward api's input

@@ -558,13 +576,13 @@ class {} : public egr::GradNodeBase {{
}}
"""

TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({x}.dtype(), {y}.dtype())) {{
TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({op_func_name}, {x}.dtype(), {y}.dtype())) {{
VLOG(5) << "got different data type, run type promotion automatically.";
LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion automatically, this may cause data type been changed.";
{op_name}
auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype());
auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type);
{x_cast}
auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type);
{return_value}
@@ -1532,6 +1550,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
elif forward_api_name in type_promote_inplace_white_list:
if name in type_promote_inplace_white_list[forward_api_name]:
if (
is_inplaced
and forward_inplace_map
and name in forward_inplace_map
):
type_promote_inputs_call_list[pos] = f"{name}"
else:
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
if is_optional:
if (
@@ -1868,16 +1898,43 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
# Forward type promotion logic
if forward_api_name in type_promote_white_list:
# only support two inputs
op_func_name = f"\"{forward_api_name}\""
x = type_promote_white_list[forward_api_name][0]
y = type_promote_white_list[forward_api_name][1]
type_promote_inputs_call_args_str = ", ".join(
type_promote_inputs_call_list
)
type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

x_cast = f"auto new_{x} = egr::PromoteCast(\"{x}\", {x}, promotion_type);"

type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
op_func_name=op_func_name,
x=x,
y=y,
x_cast=x_cast,
op_name=kernel_trans2_op_name_str,
return_value=type_promote_call_list,
)
elif forward_api_name in type_promote_inplace_white_list:
# only support two inputs
op_func_name = f"\"{forward_api_name}\""
x = type_promote_inplace_white_list[forward_api_name][0]
y = type_promote_inplace_white_list[forward_api_name][1]
type_promote_inputs_call_args_str = ", ".join(
type_promote_inputs_call_list
)
type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

x_cast = (
f"{x} = egr::PromoteCastInplace(\"{x}\", {x}, promotion_type);"
)

type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
op_func_name=op_func_name,
x=x,
y=y,
x_cast=x_cast,
op_name=kernel_trans2_op_name_str,
return_value=type_promote_call_list,
)
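The two generator branches above fill one shared template; the sketch below (a simplified stand-in for TYPE_PROMOTION_LOGIC_TEMPLATE, which additionally emits the VLOG/warning and GetPromoteDtype lines) shows that only the x_cast slot and the recursive call differ between the out-of-place and inplace cases:

# Simplified illustration of the template substitution; not the real generator code.
TEMPLATE = (
    'if (phi::NeedTypePromotion({op_func_name}, x.dtype(), y.dtype())) {{\n'
    '  {x_cast}\n'
    '  auto new_y = egr::PromoteCast("y", y, promotion_type);\n'
    '  {return_value}\n'
    '}}'
)

# Out-of-place op such as multiply: x is cast into a fresh tensor.
print(TEMPLATE.format(
    op_func_name='"multiply"',
    x_cast='auto new_x = egr::PromoteCast("x", x, promotion_type);',
    return_value='return multiply_ad_func(new_x, new_y);',
))

# Inplace op such as add_: x itself is cast in place and passed through.
print(TEMPLATE.format(
    op_func_name='"add_"',
    x_cast='x = egr::PromoteCastInplace("x", x, promotion_type);',
    return_value='return add__ad_func(x, new_y);',
))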
11 changes: 11 additions & 0 deletions paddle/fluid/eager/type_promotion_utils.h
@@ -30,4 +30,15 @@ inline paddle::Tensor PromoteCast(const std::string& input_name,
}
}

inline paddle::Tensor PromoteCastInplace(const std::string& input_name,
paddle::Tensor& input, // NOLINT
const phi::DataType& dst_dtype,
bool trace_backward = true) {
if (input.dtype() != dst_dtype) {
return paddle::experimental::cast_(input, dst_dtype);
} else {
return input;
}
}

} // namespace egr
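The difference between the two helpers mirrors the out-of-place versus inplace cast at the Python level; a rough analogue (illustration only, assuming Tensor.cast_ is exposed in this build as the counterpart of paddle::experimental::cast_):

import paddle

y = paddle.ones([2], dtype='float16')
new_y = paddle.cast(y, 'float32')   # out-of-place: y keeps float16, new_y is float32
print(y.dtype, new_y.dtype)

x = paddle.ones([2], dtype='float16')
x.cast_('float32')                  # inplace: x itself becomes float32
print(x.dtype)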
6 changes: 4 additions & 2 deletions paddle/fluid/pybind/pybind.cc
@@ -924,9 +924,11 @@ PYBIND11_MODULE(libpaddle, m) {
m.def("set_num_threads", &platform::SetNumThreads);

m.def("need_type_promotion",
[](framework::proto::VarType::Type type_x,
[](const std::string &op_name,
framework::proto::VarType::Type type_x,
framework::proto::VarType::Type type_y) {
return phi::NeedTypePromotion(framework::TransToPhiDataType(type_x),
return phi::NeedTypePromotion(op_name,
framework::TransToPhiDataType(type_x),
framework::TransToPhiDataType(type_y));
});
m.def("get_promote_dtype",
12 changes: 11 additions & 1 deletion paddle/phi/common/type_promotion.h
@@ -133,10 +133,20 @@ inline phi::DataType GetPromoteDtype(const std::string& op_name,
return phi::promoteTypes(x, y);
}

inline bool NeedTypePromotion(const DataType x, const DataType y) {
inline bool NeedTypePromotion(const std::string& op_name,
const DataType x,
const DataType y) {
// Tensor + Tensor type promotion only support calculations between
// floating-point numbers and between complex and real numbers.
if (x != y) {
// TODO(Xi Zhao): we have a special case for add now; it should be removed in the future.
#ifdef PADDLE_WITH_CUDA
if (op_name == "add" && x == DataType::FLOAT32 &&
(y == phi::DataType::BFLOAT16 || y == phi::DataType::FLOAT16)) {
return false;
}
#endif

if ((is_support_float(x) && is_support_float(y)) ||
(is_support_complex(x) || is_support_complex(y))) {
return true;
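A rough way to observe the new predicate from Python (an assumption-laden sketch: it presumes a CUDA build and that the binding is reachable via paddle.base.core as in the pybind change above). The op name is now the first argument, and only add with x=float32 and y=float16/bfloat16 is special-cased to skip promotion, so the operand order matters:

from paddle.base import core

FP32 = core.VarDesc.VarType.FP32
FP16 = core.VarDesc.VarType.FP16

print(core.need_type_promotion("add", FP32, FP16))       # expected: False on a CUDA build
print(core.need_type_promotion("add", FP16, FP32))       # expected: True
print(core.need_type_promotion("multiply", FP32, FP16))  # expected: True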
4 changes: 3 additions & 1 deletion python/paddle/base/framework.py
@@ -8304,7 +8304,9 @@ def process_type_promotion(program):
all_input_name_need_cast.append(input_arg_name)

# only support promote between float
if len(all_dtypes) == 2 and core.need_type_promotion(*all_dtypes):
if len(all_dtypes) == 2 and core.need_type_promotion(
op.type, *all_dtypes
):
common_dtype = core.get_promote_dtype(op.type, *all_dtypes)
for input_name_need_cast in all_input_name_need_cast:
var_name = op.block._var_recursive(input_name_need_cast)
4 changes: 3 additions & 1 deletion python/paddle/base/layers/math_op_patch.py
@@ -656,7 +656,9 @@ def __impl__(self, other_var):
self = astype(self, rhs_dtype)
else:
other_var = astype(other_var, lhs_dtype)
elif core.need_type_promotion(lhs_dtype, rhs_dtype):
elif core.need_type_promotion(
op_type, lhs_dtype, rhs_dtype
):
# only report warning here, real promotion deal in Executor
warnings.warn(
f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, the output will be auto-promoted"
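For the static-graph path touched above, the operator-overload patch only reports the dtype mismatch; the actual casts are inserted later by process_type_promotion(). A small sketch (names and shapes are made up, assuming the legacy static-graph IR where Variable operator overloads go through math_op_patch):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    a = paddle.static.data('a', [2, 2], dtype='float32')
    b = paddle.static.data('b', [2, 2], dtype='float16')
    # __impl__ warns that the inputs will be auto-promoted; the real promotion
    # is applied when the program is processed before execution.
    c = a * b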