Dynamically bind dtype #62508

Merged · 26 commits · Mar 13, 2024
Changes from 13 commits

Commits (26)
10501fc
🧹 chore: update Python and C++ files for eager execution properties a…
zrr1999 Mar 6, 2024
df210be
🐛 fix: add missing imports for bind_vartype and bind_datatype
zrr1999 Mar 6, 2024
c8085cd
🧹 chore: remove redundant NOLINT comment in tensor_name__doc definition
zrr1999 Mar 6, 2024
6653f91
🧹 chore(dtype): move imports
zrr1999 Mar 6, 2024
93b50c5
Empty-Commit
gouzil Mar 7, 2024
6003932
🐛 fix: fix dtype check in monkey_patch_math_tensor.py [optional body]…
zrr1999 Mar 8, 2024
4c47219
m
zrr1999 Mar 8, 2024
8ed0e9b
fix(opcode_translator): fix dtype abbreviation for TensorVariable
zrr1999 Mar 8, 2024
5f00576
chore(python/paddle/jit/sot/opcode_translator/executor/variables): re…
zrr1999 Mar 8, 2024
0af04f5
fix(dygraph): fix variable dtype comparison issue
zrr1999 Mar 10, 2024
d6196a7
fix(paddle): fix dtype comparison logic in framework.py and dtype.py
zrr1999 Mar 11, 2024
a2ff143
fix(pir_dy2static): fix dtype assignment in ParametersRecorder
zrr1999 Mar 11, 2024
8bc4d35
chore: add g_data_type_pytype to PyObject_CheckLongOrToLong function
zrr1999 Mar 11, 2024
3107638
fix DTYPE_ATTR when use env directly
SigureMo Mar 11, 2024
808d3c7
sim both vartype and datatype as TensorDtypeVariable
SigureMo Mar 11, 2024
8c08793
remove dtype convert in `MetaTensor.from_tensor`, in `use_pir_api` mo…
SigureMo Mar 11, 2024
79221d1
use `use_pir_api` in convert_np_dtype_to_dtype_ to avoid skip numpy d…
SigureMo Mar 11, 2024
80de817
dont use use_pir_api in convert_np_dtype_to_dtype_
SigureMo Mar 11, 2024
97ff3fa
refine dtype convert in tensor_patch_methods
SigureMo Mar 11, 2024
bb32a9c
use paddle.<dtype>, fix test_grad
SigureMo Mar 11, 2024
385265f
use VarType in coalesce_tensor, fix test_lstm
SigureMo Mar 11, 2024
2b11b41
move comment
SigureMo Mar 11, 2024
cf2b1de
convert dtype to VarType before init Tensor, fix error in Seg model
SigureMo Mar 11, 2024
c34ac64
adapt randint, pass CINN CI
SigureMo Mar 12, 2024
d89daf2
try to use `use_pir_api` in `convert_np_dtype_to_dtype_` again
SigureMo Mar 12, 2024
7d68baf
fix `_create_tensor`
SigureMo Mar 12, 2024
56 changes: 40 additions & 16 deletions paddle/fluid/pybind/eager_properties.cc
@@ -35,6 +35,8 @@ limitations under the License. */

#pragma GCC diagnostic ignored "-Wwrite-strings"

COMMON_DECLARE_bool(enable_pir_api);

namespace paddle {
namespace pybind {

@@ -847,25 +849,47 @@ Tensor's data type.
)DOC");
PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) {
EAGER_TRY
if (!self->tensor.defined()) {
// be same to old dygraph
return ToPyObject(framework::proto::VarType::FP32);
}
if (egr::IsVariableCompatTensor(self->tensor)) {
auto* var_tensor = static_cast<const egr::VariableCompatTensor*>(
self->tensor.impl().get());
if (var_tensor->IsType<paddle::framework::Vocab>()) {
return ToPyObject(framework::proto::VarType::RAW);
} else if (var_tensor->IsType<paddle::framework::Strings>()) {
return ToPyObject(framework::proto::VarType::STRING);
if (FLAGS_enable_pir_api) {
Contributor comment: The logic in the branch below could later be extracted into a separate function to avoid the multi-level if/for nesting.
if (!self->tensor.defined()) {
// be same to old dygraph
return ToPyObject(phi::DataType::FLOAT32);
}
if (egr::IsVariableCompatTensor(self->tensor)) {
auto* var_tensor = static_cast<const egr::VariableCompatTensor*>(
self->tensor.impl().get());
if (var_tensor->IsType<paddle::framework::Vocab>()) {
return ToPyObject(phi::DataType::UNDEFINED);
} else if (var_tensor->IsType<paddle::framework::Strings>()) {
return ToPyObject(phi::DataType::PSTRING);
Contributor comment on lines +860 to +863: This change should not affect how the current code executes, but PIR's DataType currently has no RAW or STRING types; we need to discuss and confirm how this change affects the relevant models under dygraph mode.
} else {
PADDLE_THROW(paddle::platform::errors::Unavailable(
"VariableCompatTensor only support get shape from Vocab or "
"Strings."));
}
} else {
PADDLE_THROW(paddle::platform::errors::Unavailable(
"VariableCompatTensor only support get shape from Vocab or "
"Strings."));
return ToPyObject(self->tensor.type());
}
} else {
return ToPyObject(
paddle::framework::TransToProtoVarType(self->tensor.type()));
if (!self->tensor.defined()) {
// be same to old dygraph
return ToPyObject(framework::proto::VarType::FP32);
}
if (egr::IsVariableCompatTensor(self->tensor)) {
auto* var_tensor = static_cast<const egr::VariableCompatTensor*>(
self->tensor.impl().get());
if (var_tensor->IsType<paddle::framework::Vocab>()) {
return ToPyObject(framework::proto::VarType::RAW);
} else if (var_tensor->IsType<paddle::framework::Strings>()) {
return ToPyObject(framework::proto::VarType::STRING);
} else {
PADDLE_THROW(paddle::platform::errors::Unavailable(
"VariableCompatTensor only support get shape from Vocab or "
"Strings."));
}
} else {
return ToPyObject(
paddle::framework::TransToProtoVarType(self->tensor.type()));
}
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}
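A quick way to see the user-visible effect of this getter is shown below. This is a hedged sketch, not part of the PR: it assumes a build from this branch where `FLAGS_enable_pir_api` selects the new binding, and that the enums are exposed as `paddle.base.core.DataType` and `paddle.base.core.VarDesc.VarType` (names may differ across versions).

```python
import paddle
from paddle.base import core

t = paddle.ones([2, 2])
print(t.dtype)  # prints paddle.float32 in either mode

# Under FLAGS_enable_pir_api the property returns the phi::DataType binding;
# otherwise it returns the legacy proto VarType enum.
if isinstance(t.dtype, core.DataType):
    print("PIR mode: dtype is core.DataType")
else:
    print("legacy mode: dtype is core.VarDesc.VarType")
```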
14 changes: 11 additions & 3 deletions paddle/fluid/pybind/eager_utils.cc
@@ -1077,6 +1077,12 @@ PyObject* ToPyObject(const phi::DenseTensor* value) {
return obj.ptr();
}

PyObject* ToPyObject(const phi::DataType& dtype) {
auto obj = ::pybind11::cast(dtype);
obj.inc_ref();
return obj.ptr();
}

PyObject* ToPyObject(const pir::Value& value) {
auto obj = ::pybind11::cast(value);
obj.inc_ref();
@@ -2410,9 +2416,11 @@ paddle::DataType CastPyArg2DataType(PyObject* obj,
if (obj == Py_None) {
return phi::DataType::UNDEFINED;
}

framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
return framework::TransToPhiDataType(type);
if (PyObject_TypeCheck(obj, g_vartype_pytype)) {
framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
return framework::TransToPhiDataType(type);
}
return CastPyArg2DataTypeDirectly(obj, op_type, arg_pos);
}

paddle::Tensor PyTensorHook::operator()(const paddle::Tensor& var) {
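With this change, an argument parsed through `CastPyArg2DataType` accepts either the legacy `VarType` enum or the new `DataType` enum from Python. The sketch below is only an illustration under that assumption; it is not a test from the PR, and it assumes `paddle.cast` forwards the enum value unchanged to the C++ parsing path.

```python
import paddle
from paddle.base import core

x = paddle.rand([3])

# Both enum flavours should be accepted by ops whose dtype argument is parsed
# through CastPyArg2DataType after this change.
y_legacy = paddle.cast(x, core.VarDesc.VarType.FP64)
y_pir = paddle.cast(x, core.DataType.FLOAT64)
assert y_legacy.dtype == y_pir.dtype
```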
1 change: 1 addition & 0 deletions paddle/fluid/pybind/eager_utils.h
@@ -148,6 +148,7 @@ PyObject* ToPyObject(const phi::distributed::Placements& value);
PyObject* ToPyObject(const phi::SelectedRows* value);
PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype);
PyObject* ToPyObject(const paddle::framework::proto::VarType& type);
PyObject* ToPyObject(const phi::DataType& type);
PyObject* ToPyObject(const void* value);
PyObject* ToPyObject(const std::unordered_map<int, int>& value);
PyObject* ToPyObject(
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/op_function_common.cc
@@ -64,6 +64,7 @@ class OpAttrTypeMap {
};

extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_data_type_pytype;
extern PyTypeObject* g_blockdesc_pytype;
extern PyTypeObject* p_tensor_type;

@@ -72,6 +73,7 @@ bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); }
bool PyObject_CheckLongOrToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT
PyObject_TypeCheck(*obj, g_data_type_pytype) || // NOLINT
(PyObject_TypeCheck(*obj, p_tensor_type) && // NOLINT
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
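Adding `g_data_type_pytype` here lets `DataType` values pass the same int-like check that `VarType` values already passed, so attribute-parsing paths that expect an integer-convertible dtype keep working. The snippet below is a hedged illustration of that idea only; it assumes both pybind enums are integer-convertible, which is not stated in the PR.

```python
from paddle.base import core

# Both enums expose an integer code, which is what the int-like check in
# op_function_common.cc ultimately relies on when parsing op attributes.
legacy = core.VarDesc.VarType.FP32
pir = core.DataType.FLOAT32
print(int(legacy), int(pir))  # two (generally different) integer codes
```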
2 changes: 1 addition & 1 deletion python/paddle/base/dygraph/math_op_patch.py
@@ -87,7 +87,7 @@ def astype(self, dtype):
>>> print("new tensor's dtype is: {}".format(new_tensor.dtype))
new tensor's dtype is: paddle.float32
"""
if not isinstance(dtype, core.VarDesc.VarType):
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
return _C_ops.cast(self, dtype)

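With the widened `isinstance` check, `astype` skips the numpy-dtype conversion for both enum flavours and only converts strings and numpy dtypes. A small hedged usage sketch (illustrative, not from the PR):

```python
import paddle

x = paddle.to_tensor([1, 2, 3], dtype="int32")

# Both spellings should reach the same cast kernel: strings and numpy dtypes
# are converted first, while VarType / DataType values pass straight through.
a = x.astype("float32")
b = x.astype(paddle.float32)  # VarType or DataType, depending on FLAGS_enable_pir_api
assert a.dtype == b.dtype
```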
9 changes: 7 additions & 2 deletions python/paddle/base/dygraph/tensor_patch_methods.py
@@ -25,6 +25,7 @@
_PADDLE_DTYPE_2_NUMPY_DTYPE,
convert_uint16_to_float,
)
from paddle.base.framework import paddle_type_to_proto_type
from paddle.profiler.utils import in_profiler_mode
from paddle.utils import deprecated

@@ -219,10 +220,14 @@ def set_value(self, value):
else:
dtype = convert_np_dtype_to_dtype_(value.dtype)

self_dtype = self.dtype
if isinstance(self_dtype, core.DataType):
self_dtype = paddle_type_to_proto_type[self_dtype]

assert (
self.dtype == dtype
self_dtype == dtype
), "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
self.name, self.dtype, dtype
self.name, self_dtype, dtype
)

# NOTE(wuweilong): self could be Tensor, the subsequent behavior are defined in different files
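The extra mapping keeps the dtype assertion meaningful when `self.dtype` is a `DataType` while the loaded value was converted to a proto `VarType`. Roughly, the comparison now behaves as in the hedged sketch below, which mirrors the logic added to `set_value` (the helper name is hypothetical):

```python
from paddle.base import core
from paddle.base.framework import paddle_type_to_proto_type

def dtypes_match(tensor_dtype, loaded_dtype):
    # Normalize a phi DataType to the legacy proto VarType before comparing,
    # mirroring the check added to Tensor.set_value in this PR.
    if isinstance(tensor_dtype, core.DataType):
        tensor_dtype = paddle_type_to_proto_type[tensor_dtype]
    return tensor_dtype == loaded_dtype
```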
4 changes: 3 additions & 1 deletion python/paddle/base/framework.py
@@ -1350,7 +1350,9 @@ def _create_tensor(
**kwargs,
):
if dtype is not None:
if not isinstance(dtype, core.VarDesc.VarType):
if isinstance(dtype, core.DataType):
dtype = paddle_type_to_proto_type[dtype]
elif not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

eager_tensor = core.eager.Tensor(
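`core.eager.Tensor` still expects the legacy enum at construction time, so `_create_tensor` now maps a `DataType` back to a `VarType` first. The sketch below restates that normalization as a standalone helper (the function name is hypothetical; the branching mirrors the diff above):

```python
from paddle.base import core
from paddle.base.framework import (
    convert_np_dtype_to_dtype_,
    paddle_type_to_proto_type,
)

def normalize_create_tensor_dtype(dtype):
    # Mirror of the branch added to _create_tensor: a phi DataType is mapped
    # back to the proto VarType expected by core.eager.Tensor, while numpy
    # dtypes and strings keep going through convert_np_dtype_to_dtype_.
    if isinstance(dtype, core.DataType):
        return paddle_type_to_proto_type[dtype]
    if not isinstance(dtype, core.VarDesc.VarType):
        return convert_np_dtype_to_dtype_(dtype)
    return dtype
```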
138 changes: 117 additions & 21 deletions python/paddle/framework/dtype.py
@@ -12,32 +12,130 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from ..base import framework
from ..base.core import (
DataType,
VarDesc,
finfo as core_finfo,
iinfo as core_iinfo,
)
from ..base.data_feeder import _NUMPY_DTYPE_2_PADDLE_DTYPE

dtype = VarDesc.VarType
dtype.__qualname__ = "dtype"
dtype.__module__ = "paddle"

uint8 = VarDesc.VarType.UINT8
int8 = VarDesc.VarType.INT8
int16 = VarDesc.VarType.INT16
int32 = VarDesc.VarType.INT32
int64 = VarDesc.VarType.INT64

float32 = VarDesc.VarType.FP32
float64 = VarDesc.VarType.FP64
float16 = VarDesc.VarType.FP16
bfloat16 = VarDesc.VarType.BF16

complex64 = VarDesc.VarType.COMPLEX64
complex128 = VarDesc.VarType.COMPLEX128

bool = VarDesc.VarType.BOOL
def bind_vartype():
global dtype
global uint8
global int8
global int16
global int32
global int64
global float32
global float64
global float16
global bfloat16
global complex64
global complex128
global bool

dtype = VarDesc.VarType
dtype.__qualname__ = "dtype"
dtype.__module__ = "paddle"

uint8 = VarDesc.VarType.UINT8
int8 = VarDesc.VarType.INT8
int16 = VarDesc.VarType.INT16
int32 = VarDesc.VarType.INT32
int64 = VarDesc.VarType.INT64

float32 = VarDesc.VarType.FP32
float64 = VarDesc.VarType.FP64
float16 = VarDesc.VarType.FP16
bfloat16 = VarDesc.VarType.BF16

complex64 = VarDesc.VarType.COMPLEX64
complex128 = VarDesc.VarType.COMPLEX128

bool = VarDesc.VarType.BOOL

paddle.dtype = dtype
paddle.uint8 = uint8
paddle.int8 = int8
paddle.int16 = int16
paddle.int32 = int32
paddle.int64 = int64

paddle.float32 = float32
paddle.float64 = float64
paddle.float16 = float16
paddle.bfloat16 = bfloat16

paddle.complex64 = complex64
paddle.complex128 = complex128
paddle.bool = bool


def bind_datatype():
global dtype
global uint8
global int8
global int16
global int32
global int64
global float32
global float64
global float16
global bfloat16
global complex64
global complex128
global bool

dtype = DataType
dtype.__qualname__ = "dtype"
dtype.__module__ = "paddle"

uint8 = DataType.UINT8
int8 = DataType.INT8
int16 = DataType.INT16
int32 = DataType.INT32
int64 = DataType.INT64

float32 = DataType.FLOAT32
float64 = DataType.FLOAT64
float16 = DataType.FLOAT16
bfloat16 = DataType.BFLOAT16

complex64 = DataType.COMPLEX64
complex128 = DataType.COMPLEX128

bool = DataType.BOOL

paddle.dtype = dtype
paddle.uint8 = uint8
paddle.int8 = int8
paddle.int16 = int16
paddle.int32 = int32
paddle.int64 = int64

paddle.float32 = float32
paddle.float64 = float64
paddle.float16 = float16
paddle.bfloat16 = bfloat16

paddle.complex64 = complex64
paddle.complex128 = complex128
paddle.bool = bool


enable_pir_api = framework.get_flags("FLAGS_enable_pir_api")[
"FLAGS_enable_pir_api"
]

if enable_pir_api:
bind_datatype()
else:
bind_vartype()
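The net effect of the two `bind_*` functions is that the `paddle.float32`-style aliases point at one enum family or the other, decided once at import time by `FLAGS_enable_pir_api`. A hedged sketch of how downstream code can stay agnostic (illustrative only, not from the PR):

```python
import paddle
from paddle.base import core

# paddle.float32 is VarDesc.VarType.FP32 or DataType.FLOAT32, depending on
# which bind_* function ran at import time; downstream code can compare
# against the alias instead of hard-coding either enum family.
x = paddle.zeros([4], dtype=paddle.float32)
assert x.dtype == paddle.float32
print(type(paddle.float32))  # core.VarDesc.VarType or core.DataType
```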


def iinfo(dtype):
@@ -130,9 +228,7 @@ def finfo(dtype):
"""
import paddle

if paddle.base.framework.in_pir_mode() and isinstance(
dtype, paddle.pir.core.DataType
):
if isinstance(dtype, paddle.pir.core.DataType):
dtype = paddle.base.framework.paddle_type_to_proto_type[dtype]
elif dtype in _NUMPY_DTYPE_2_PADDLE_DTYPE:
dtype = _NUMPY_DTYPE_2_PADDLE_DTYPE[dtype]
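Dropping the `in_pir_mode()` condition means `finfo` converts a `DataType` argument regardless of execution mode. A minimal hedged usage sketch:

```python
import paddle

# finfo now accepts the dtype alias whether it is bound to VarType.FP32 or
# DataType.FLOAT32.
fi = paddle.finfo(paddle.float32)
print(fi.max, fi.eps)
```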
6 changes: 5 additions & 1 deletion python/paddle/jit/pir_dy2static/parameter_recorder.py
@@ -14,6 +14,7 @@

import paddle
from paddle.autograd.backward_utils import ValueDict
from paddle.framework import core

from ..dy2static.program_translator import _program_hash, synchronized

@@ -37,8 +38,11 @@ def get(self, program, tensor):
mappings = self.tensor2value[key]
if id(tensor) not in mappings:
non_used_initializer = paddle.nn.initializer.Constant(0.0)
dtype = tensor.dtype
if isinstance(dtype, core.VarDesc.VarType):
dtype = vartype_to_datatype[dtype]
value = create_parameter(
dtype=vartype_to_datatype[tensor.dtype],
dtype=dtype,
shape=tensor.shape,
type=tensor.type,
initializer=non_used_initializer,
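In the PIR dynamic-to-static path, `create_parameter` expects a `DataType`, so a tensor carrying a legacy `VarType` dtype is translated first. The sketch below restates that conversion as a standalone helper; the helper name and the `paddle.pir.core` import path are assumptions, not part of the PR.

```python
from paddle.base import core
from paddle.pir.core import vartype_to_datatype  # import path assumed

def to_datatype(dtype):
    # Mirror of the branch added to ParametersRecorder.get: only legacy
    # VarType values need translating; DataType values pass through unchanged.
    if isinstance(dtype, core.VarDesc.VarType):
        return vartype_to_datatype[dtype]
    return dtype
```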