Tensor type promotion for static mode #59586

Closed
Changes from all commits (31 commits)
39d3341
add type promotion table.
zxcd Nov 23, 2023
7616912
fix codestyle.
zxcd Nov 23, 2023
f7992a7
add python table.
zxcd Nov 23, 2023
70993d6
fix dtype.
zxcd Nov 24, 2023
164daec
remove useless note
zxcd Nov 24, 2023
863f139
fix static-check
zxcd Nov 27, 2023
7fddccf
Merge branch 'type_promotion_stage1_table_only' into tensor_type_prom…
zoooo0820 Nov 28, 2023
3cafffa
add eager T+T logic.
zxcd Nov 29, 2023
a1c649a
remove useless file.
zxcd Nov 29, 2023
4ce9034
remove useless line.
zxcd Nov 29, 2023
359a689
fix
zxcd Nov 29, 2023
5af5d8c
dtype promotion for operator overload in static mode
zoooo0820 Nov 29, 2023
4c66404
Merge branch 'develop' into tensor_type_promotion_for_static_mode
zoooo0820 Nov 29, 2023
a92bb8f
fix
zxcd Nov 29, 2023
1b069a2
only support float series
zoooo0820 Nov 29, 2023
83ec3e0
update
zxcd Nov 29, 2023
2636832
fix note.
zxcd Nov 29, 2023
7c4b08e
mv common logic to common dir.
zxcd Nov 30, 2023
3a996f8
fix
zxcd Nov 30, 2023
6784ab7
remove deal for int.
zxcd Nov 30, 2023
defb035
remove int.
zxcd Nov 30, 2023
49b4cf4
only for complie
zxcd Nov 30, 2023
0f0f7b1
fix median / cross_entropy_loss
zoooo0820 Nov 30, 2023
bfc51fd
keep old illogical logic for compatibility reasons
zoooo0820 Nov 30, 2023
64e3d1b
Merge branch 'tensor_type_promotion_for_static_mode' into tensor_type…
zoooo0820 Nov 30, 2023
8f060a6
pybind the type_promotion function; remove python function; remove fl…
zoooo0820 Dec 1, 2023
b893736
remove change for dygraph
zoooo0820 Dec 1, 2023
a06b28c
rename type_promotion_table.h -> data_type_promotion.h
zoooo0820 Dec 1, 2023
fbb4704
convert dtype in Block.append_op; support where op
zoooo0820 Dec 4, 2023
88aad1e
add warnings
zoooo0820 Dec 4, 2023
12a2bfb
only promote if needed
zoooo0820 Dec 5, 2023
87 changes: 9 additions & 78 deletions paddle/fluid/eager/type_promotion_utils.h
@@ -12,90 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/phi/common/data_type_promotion.h"

namespace egr {

inline int DataTypeToNum(const phi::DataType& dtype) {
  switch (dtype) {
    case phi::DataType::UINT8:
      return 0;
    case phi::DataType::INT8:
      return 1;
    case phi::DataType::INT16:
      return 2;
    case phi::DataType::INT32:
      return 3;
    case phi::DataType::INT64:
      return 4;
    case phi::DataType::FLOAT16:
      return 5;
    case phi::DataType::FLOAT32:
      return 6;
    case phi::DataType::FLOAT64:
      return 7;
    case phi::DataType::COMPLEX64:
      return 8;
    case phi::DataType::COMPLEX128:
      return 9;
    case phi::DataType::BOOL:
      return 10;
    case phi::DataType::BFLOAT16:
      return 11;
    default:
      PD_THROW("Invalid enum data type for type promote `", dtype, "`.");
  }
}

static inline bool is_support_float(phi::DataType dtype) {
  if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::FLOAT32 ||
      dtype == phi::DataType::FLOAT64 || dtype == phi::DataType::BFLOAT16) {
    return true;
  } else {
    return false;
  }
}

static inline bool is_support_int(phi::DataType dtype) {
  if (dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64) {
    return true;
  } else {
    return false;
  }
}

inline paddle::Tensor PromoteCast(const std::string& input_name,
                                  const paddle::Tensor& input,
                                  const phi::DataType& dst_dtype,
                                  bool trace_backward = true) {
  if (input.dtype() != dst_dtype) {
    return Cast(input, dst_dtype, trace_backward);
  } else {
    return input;
  }
}

inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) {
  constexpr auto u1 = phi::DataType::UINT8;
  constexpr auto i1 = phi::DataType::INT8;
  constexpr auto i2 = phi::DataType::INT16;
  constexpr auto i4 = phi::DataType::INT32;
  constexpr auto i8 = phi::DataType::INT64;
  constexpr auto f2 = phi::DataType::FLOAT16;
  constexpr auto f4 = phi::DataType::FLOAT32;
  constexpr auto f8 = phi::DataType::FLOAT64;
  constexpr auto c4 = phi::DataType::COMPLEX64;
  constexpr auto c8 = phi::DataType::COMPLEX128;
  constexpr auto b1 = phi::DataType::BOOL;
  constexpr auto bf = phi::DataType::BFLOAT16;

  static constexpr phi::DataType _promoteTypesLookup[12][12] = {
      /*        u1  i1  i2  i4  i8  f2  f4  f8  c4  c8  b1  bf*/
      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf},
      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf},
      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf},
      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf},
      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf},
      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4},
      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4},
      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8},
      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4},
      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf},
      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf},
  };

  return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)];
}

} // namespace egr
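
For readers following the Python side, PromoteCast is just a cast-if-needed step. A rough Python equivalent (a sketch only; promote_cast is not a Paddle API):

import paddle

def promote_cast(x, dst_dtype):
    # Cast only when the dtype actually differs, mirroring PromoteCast above.
    return paddle.cast(x, dst_dtype) if x.dtype != dst_dtype else x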
17 changes: 17 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -204,6 +204,7 @@ limitations under the License. */
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/api/include/operants_manager.h"
#include "paddle/phi/api/include/tensor_operants.h"
#include "paddle/phi/common/data_type_promotion.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
@@ -883,6 +884,22 @@ PYBIND11_MODULE(libpaddle, m) {
        &paddle::prim::PrimCommonUtils::SetTargetGradName);
  m.def("set_num_threads", &platform::SetNumThreads);

  m.def("need_type_promotion",
        [](framework::proto::VarType::Type type_x,
           framework::proto::VarType::Type type_y) {
          return phi::NeedTypePromotion(framework::TransToPhiDataType(type_x),
                                        framework::TransToPhiDataType(type_y));
        });
  m.def("get_promote_dtype",
        [](const std::string &op_name,
           framework::proto::VarType::Type type_x,
           framework::proto::VarType::Type type_y) {
          return framework::TransToProtoVarType(
              phi::GetPromoteDtype(op_name,
                                   framework::TransToPhiDataType(type_x),
                                   framework::TransToPhiDataType(type_y)));
        });

  m.def("disable_signal_handler", &DisableSignalHandler);

  m.def("clear_gradients",
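Once re-exported through python/paddle/base/core.py (below), the two new bindings can be queried directly. A small usage sketch, assuming static-graph Variable.dtype values are accepted as the VarType arguments (as Block.append_op does later in this diff):

import paddle
from paddle.base import core

paddle.enable_static()
x = paddle.static.data("x", [4], dtype="float16")
y = paddle.static.data("y", [4], dtype="float32")

if core.need_type_promotion(x.dtype, y.dtype):
    # get_promote_dtype returns the proto VarType of the common dtype.
    common = core.get_promote_dtype("elementwise_add", x.dtype, y.dtype)
    print(common)  # expected: the VarType entry for float32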
115 changes: 115 additions & 0 deletions paddle/phi/common/data_type_promotion.h
@@ -0,0 +1,115 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/phi/common/data_type.h"
namespace phi {

inline int DataTypeToNum(const DataType& dtype) {
  switch (dtype) {
    case DataType::UINT8:
      return 0;
    case DataType::INT8:
      return 1;
    case DataType::INT16:
      return 2;
    case DataType::INT32:
      return 3;
    case DataType::INT64:
      return 4;
    case DataType::FLOAT16:
      return 5;
    case DataType::FLOAT32:
      return 6;
    case DataType::FLOAT64:
      return 7;
    case DataType::COMPLEX64:
      return 8;
    case DataType::COMPLEX128:
      return 9;
    case DataType::BOOL:
      return 10;
    case DataType::BFLOAT16:
      return 11;
    default:
      PD_THROW("Invalid enum data type for type promote `", dtype, "`.");
  }
}

inline static DataType promoteTypes(DataType x, DataType y) {
  constexpr auto u1 = DataType::UINT8;
  constexpr auto i1 = DataType::INT8;
  constexpr auto i2 = DataType::INT16;
  constexpr auto i4 = DataType::INT32;
  constexpr auto i8 = DataType::INT64;
  constexpr auto f2 = DataType::FLOAT16;
  constexpr auto f4 = DataType::FLOAT32;
  constexpr auto f8 = DataType::FLOAT64;
  constexpr auto c4 = DataType::COMPLEX64;
  constexpr auto c8 = DataType::COMPLEX128;
  constexpr auto b1 = DataType::BOOL;
  constexpr auto bf = DataType::BFLOAT16;

  const int total_type_num = 12;

  static constexpr DataType
      _promoteTypesLookup[total_type_num][total_type_num] = {
          /*        u1  i1  i2  i4  i8  f2  f4  f8  c4  c8  b1  bf*/
          /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf},
          /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf},
          /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf},
          /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf},
          /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf},
          /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4},
          /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4},
          /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8},
          /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4},
          /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
          /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf},
          /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf},
      };
  return _promoteTypesLookup[DataTypeToNum(x)][DataTypeToNum(y)];
}

static inline bool is_support_float(DataType dtype) {
  if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 ||
      dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) {
    return true;
  } else {
    return false;
  }
}

inline phi::DataType GetPromoteDtype(const std::string& op_name,
                                     const DataType x,
                                     const DataType y) {
  // future will deal this by different rule
  if (op_name == "greater_than") {
    // bool logic
    return DataType::BOOL;
  } else {
    return phi::promoteTypes(x, y);
  }
}

inline bool NeedTypePromotion(const DataType x, const DataType y) {
  // Tensor + Tensor only support type promotion for float type
  if ((x != y) && is_support_float(x) && is_support_float(y)) {
    return true;
  } else {
    return false;
  }
}

} // namespace phi
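
Because NeedTypePromotion restricts T+T promotion to the float series, only the float rows of the lookup table are reachable for now. A small Python transcription of those entries (for illustration only; not part of the patch):

# Promotion results for mixed float dtypes, read off the
# _promoteTypesLookup table above (f2/f4/f8/bf rows and columns).
FLOAT_PROMOTION = {
    frozenset(["float16", "float32"]): "float32",
    frozenset(["float16", "float64"]): "float64",
    frozenset(["float32", "float64"]): "float64",
    frozenset(["bfloat16", "float16"]): "float32",
    frozenset(["bfloat16", "float32"]): "float32",
    frozenset(["bfloat16", "float64"]): "float64",
}

def promote_float_types(x: str, y: str) -> str:
    return x if x == y else FLOAT_PROMOTION[frozenset([x, y])]

assert promote_float_types("bfloat16", "float16") == "float32"
assert promote_float_types("float16", "float64") == "float64"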
1 change: 0 additions & 1 deletion python/paddle/base/__init__.py
@@ -126,7 +126,6 @@
    HeterXpuTrainer,
)
from .backward import append_backward
from . import type_promotion

Tensor = LoDTensor
enable_imperative = enable_dygraph
3 changes: 3 additions & 0 deletions python/paddle/base/core.py
@@ -342,6 +342,9 @@ def to_list(s):
    _set_prim_target_grad_name,
)

# type promotion
from .libpaddle import need_type_promotion, get_promote_dtype # noqa: F401

# isort: on
if sys.platform != 'win32':
from .libpaddle import ( # noqa: F401
54 changes: 52 additions & 2 deletions python/paddle/base/framework.py
@@ -56,6 +56,14 @@
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
_global_flags_ = core.globals()

# TODO(zoooo0820): unify this dict of dygraph and static at Pybind
SUPPORT_PROMOTION_OPS_AND_INPUTNAME = {
    "elementwise_add": ['X', 'Y'],
    "elementwise_sub": ['X', 'Y'],
    "elementwise_mul": ['X', 'Y'],
    "where": ['X', 'Y'],
}


def _global_flags():
    return _global_flags_
@@ -4383,6 +4391,43 @@ def _is_inited_by(block, var):
        param.stop_gradient = stop_gradient
        return param

    def _type_promotion_for_inputs(self, op_type, inputs):
        need_transed_var_names = SUPPORT_PROMOTION_OPS_AND_INPUTNAME.get(
            op_type, None
        )
        if need_transed_var_names is None:
            return

        all_dtypes = []
        for input_name in inputs.keys():
            if input_name in need_transed_var_names:
                var_dtype = (
                    inputs[input_name][0].dtype
                    if isinstance(inputs[input_name], (list, tuple))
                    else inputs[input_name].dtype
                )
                all_dtypes.append(var_dtype)

        if core.need_type_promotion(*all_dtypes):
            common_dtype = core.get_promote_dtype(op_type, *all_dtypes)
            warnings.warn(
                f"The input dtypes of OP {op_type} are {all_dtypes}, the output will be auto-promoted to {common_dtype}"
            )

            for input_name in inputs.keys():
                if input_name in need_transed_var_names:
                    var_dtype = (
                        inputs[input_name][0].dtype
                        if isinstance(inputs[input_name], (list, tuple))
                        else inputs[input_name].dtype
                    )
                    if var_dtype != common_dtype:
                        inputs[input_name] = (
                            [inputs[input_name][0].astype(common_dtype)]
                            if isinstance(inputs[input_name], (list, tuple))
                            else inputs[input_name].astype(common_dtype)
                        )

    def append_op(self, *args, **kwargs):
        """
        Appends a new Operator according to the giving arguments.
@@ -4394,6 +4439,7 @@ def append_op(self, *args, **kwargs):
op_type = kwargs.get("type", None)
if in_dygraph_mode():
attrs = kwargs.get("attrs", {})
inputs = kwargs.get("inputs", {})
warnings.warn(
"Op `%s` is executed through `append_op` under the dynamic mode, "
"the corresponding API implementation needs to be upgraded to "
@@ -4409,14 +4455,16 @@ def append_op(self, *args, **kwargs):
                attrs=attrs,
            )

            self._type_promotion_for_inputs(op_type, inputs)

            # record ops in tracer rather than blocks
            #
            # TODO(minqiyang): add op stop_gradient support in static graph mode too.
            # currently, we only support stop_gradient in dygraph mode.

            _dygraph_tracer().trace_op(
                op_type,
                kwargs.get("inputs", {}),
                inputs,
                kwargs.get("outputs", {}),
                attrs if attrs else {},
                kwargs.get("stop_gradient", False),
@@ -4440,9 +4488,11 @@ def pass_stop_gradient(ins, outs):
                        if isinstance(var, Variable):
                            var.stop_gradient = True

            op_desc = self.desc.append_op()
            inputs = kwargs.get("inputs", None)
            outputs = kwargs.get("outputs", None)

            self._type_promotion_for_inputs(op_type, inputs)
            op_desc = self.desc.append_op()
            # NOTE(Aurelius84): In case of @to_static, all Tensor(s) should
            # be converted into Variable(s) with same name and block location.
            # This is ONE and ONLY logic of type transformation of dy2static.
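
Stripped of the list/tuple handling, _type_promotion_for_inputs reduces to the following standalone sketch (simplified; the callables are injected here as parameters, while the real code above calls core.need_type_promotion, core.get_promote_dtype, and Variable.astype directly):

import warnings

SUPPORTED = {
    "elementwise_add": ["X", "Y"],
    "elementwise_sub": ["X", "Y"],
    "elementwise_mul": ["X", "Y"],
    "where": ["X", "Y"],
}

def promote_op_inputs(op_type, inputs, need_promotion, promote_dtype, cast):
    # inputs: dict mapping input names ("X", "Y", ...) to variables with .dtype
    names = SUPPORTED.get(op_type)
    if names is None:
        return inputs
    dtypes = [inputs[n].dtype for n in names if n in inputs]
    if need_promotion(*dtypes):
        common = promote_dtype(op_type, *dtypes)
        warnings.warn(f"Inputs of {op_type} are {dtypes}; promoting to {common}.")
        for n in names:
            if n in inputs and inputs[n].dtype != common:
                inputs[n] = cast(inputs[n], common)
    return inputs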