diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc
index fa926618bdf8db..ba857e9cdbfbdb 100644
--- a/paddle/fluid/pybind/eager_properties.cc
+++ b/paddle/fluid/pybind/eager_properties.cc
@@ -35,6 +35,8 @@ limitations under the License. */
 #pragma GCC diagnostic ignored "-Wwrite-strings"
 
+COMMON_DECLARE_bool(enable_pir_api);
+
 namespace paddle {
 namespace pybind {
 
@@ -847,25 +849,47 @@ Tensor's data type.
 )DOC");
 
 PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) {
   EAGER_TRY
-  if (!self->tensor.defined()) {
-    // be same to old dygraph
-    return ToPyObject(framework::proto::VarType::FP32);
-  }
-  if (egr::IsVariableCompatTensor(self->tensor)) {
-    auto* var_tensor = static_cast<framework::VariableCompatTensor*>(
-        self->tensor.impl().get());
-    if (var_tensor->IsType<framework::Vocab>()) {
-      return ToPyObject(framework::proto::VarType::RAW);
-    } else if (var_tensor->IsType<framework::Strings>()) {
-      return ToPyObject(framework::proto::VarType::STRING);
+  if (FLAGS_enable_pir_api) {
+    if (!self->tensor.defined()) {
+      // be same to old dygraph
+      return ToPyObject(phi::DataType::FLOAT32);
+    }
+    if (egr::IsVariableCompatTensor(self->tensor)) {
+      auto* var_tensor = static_cast<framework::VariableCompatTensor*>(
+          self->tensor.impl().get());
+      if (var_tensor->IsType<framework::Vocab>()) {
+        return ToPyObject(phi::DataType::UNDEFINED);
+      } else if (var_tensor->IsType<framework::Strings>()) {
+        return ToPyObject(phi::DataType::PSTRING);
+      } else {
+        PADDLE_THROW(paddle::platform::errors::Unavailable(
+            "VariableCompatTensor only support get shape from Vocab or "
+            "Strings."));
+      }
     } else {
-      PADDLE_THROW(paddle::platform::errors::Unavailable(
-          "VariableCompatTensor only support get shape from Vocab or "
-          "Strings."));
+      return ToPyObject(self->tensor.type());
     }
   } else {
-    return ToPyObject(
-        paddle::framework::TransToProtoVarType(self->tensor.type()));
+    if (!self->tensor.defined()) {
+      // be same to old dygraph
+      return ToPyObject(framework::proto::VarType::FP32);
+    }
+    if (egr::IsVariableCompatTensor(self->tensor)) {
+      auto* var_tensor = static_cast<framework::VariableCompatTensor*>(
+          self->tensor.impl().get());
+      if (var_tensor->IsType<framework::Vocab>()) {
+        return ToPyObject(framework::proto::VarType::RAW);
+      } else if (var_tensor->IsType<framework::Strings>()) {
+        return ToPyObject(framework::proto::VarType::STRING);
+      } else {
+        PADDLE_THROW(paddle::platform::errors::Unavailable(
+            "VariableCompatTensor only support get shape from Vocab or "
+            "Strings."));
+      }
+    } else {
+      return ToPyObject(
+          paddle::framework::TransToProtoVarType(self->tensor.type()));
+    }
   }
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc
index 851e498bac8b3a..aba7c99662bbee 100644
--- a/paddle/fluid/pybind/eager_utils.cc
+++ b/paddle/fluid/pybind/eager_utils.cc
@@ -1077,6 +1077,12 @@ PyObject* ToPyObject(const phi::DenseTensor* value) {
   return obj.ptr();
 }
 
+PyObject* ToPyObject(const phi::DataType& dtype) {
+  auto obj = ::pybind11::cast(dtype);
+  obj.inc_ref();
+  return obj.ptr();
+}
+
 PyObject* ToPyObject(const pir::Value& value) {
   auto obj = ::pybind11::cast(value);
   obj.inc_ref();
@@ -2410,9 +2416,11 @@ paddle::DataType CastPyArg2DataType(PyObject* obj,
   if (obj == Py_None) {
     return phi::DataType::UNDEFINED;
   }
-
-  framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
-  return framework::TransToPhiDataType(type);
+  if (PyObject_TypeCheck(obj, g_vartype_pytype)) {
+    framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos);
+    return framework::TransToPhiDataType(type);
+  }
+  return CastPyArg2DataTypeDirectly(obj, op_type, arg_pos);
 }
 
 paddle::Tensor PyTensorHook::operator()(const paddle::Tensor& var) {
diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h
index 2511ddb57dbb5b..e56741aa90776a 100644
--- a/paddle/fluid/pybind/eager_utils.h
+++ b/paddle/fluid/pybind/eager_utils.h
@@ -148,6 +148,7 @@ PyObject* ToPyObject(const phi::distributed::Placements& value);
 PyObject* ToPyObject(const phi::SelectedRows* value);
 PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype);
 PyObject* ToPyObject(const paddle::framework::proto::VarType& type);
+PyObject* ToPyObject(const phi::DataType& type);
 PyObject* ToPyObject(const void* value);
 PyObject* ToPyObject(const std::unordered_map& value);
 PyObject* ToPyObject(
diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc
index 5d7977ce5c4422..f8f1424ded2432 100644
--- a/paddle/fluid/pybind/op_function_common.cc
+++ b/paddle/fluid/pybind/op_function_common.cc
@@ -64,6 +64,7 @@ class OpAttrTypeMap {
 };
 
 extern PyTypeObject* g_vartype_pytype;
+extern PyTypeObject* g_data_type_pytype;
 extern PyTypeObject* g_blockdesc_pytype;
 extern PyTypeObject* p_tensor_type;
 
@@ -72,6 +73,7 @@ bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); }
 bool PyObject_CheckLongOrToLong(PyObject** obj) {
   if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
       PyObject_TypeCheck(*obj, g_vartype_pytype) ||    // NOLINT
+      PyObject_TypeCheck(*obj, g_data_type_pytype) ||  // NOLINT
       (PyObject_TypeCheck(*obj, p_tensor_type) &&      // NOLINT
        (((TensorObject*)(*obj))->tensor.numel() == 1))) {  // NOLINT
     return true;
diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py
index 3f7b7a40ffa461..916dedea28418e 100644
--- a/python/paddle/base/dygraph/math_op_patch.py
+++ b/python/paddle/base/dygraph/math_op_patch.py
@@ -87,7 +87,7 @@ def astype(self, dtype):
             >>> print("new tensor's dtype is: {}".format(new_tensor.dtype))
             new tensor's dtype is: paddle.float32
         """
-        if not isinstance(dtype, core.VarDesc.VarType):
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
             dtype = convert_np_dtype_to_dtype_(dtype)
         return _C_ops.cast(self, dtype)
 
diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py
index 275ab3a232d96c..e5e6fda5bc5969 100644
--- a/python/paddle/base/dygraph/tensor_patch_methods.py
+++ b/python/paddle/base/dygraph/tensor_patch_methods.py
@@ -188,13 +188,12 @@ def set_value(self, value):
                 ...     linear.weight.set_value(custom_weight)  # change existing weight
                 ...     out = linear(t)  # call with different weight
         """
-        base_tensor = core.eager.Tensor
         assert isinstance(
-            value, (np.ndarray, base_tensor, dict, str)
+            value, (np.ndarray, paddle.Tensor, dict, str)
        ), "Variable set_value function, arguments type only support Variable, numpy, Tensor, dict, string."
 
         if self.is_dist():
             assert isinstance(
-                value, (np.ndarray, base_tensor)
+                value, (np.ndarray, paddle.Tensor)
             ), "For set_value function of dist tensor, arguments type only support numpy or Tensor."
 
         if isinstance(value, (dict, str)):
@@ -214,8 +213,10 @@ def set_value(self, value):
                 self.name, self.shape, value.shape
             )
 
-        if isinstance(value, base_tensor):
+        if isinstance(value, paddle.Tensor):
             dtype = value.dtype
+        elif paddle.framework.use_pir_api():
+            dtype = paddle.pir.core.convert_np_dtype_to_dtype_(value.dtype)
         else:
             dtype = convert_np_dtype_to_dtype_(value.dtype)
 
diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py
index a306004bca62a8..1d3bbd28873c2e 100644
--- a/python/paddle/base/framework.py
+++ b/python/paddle/base/framework.py
@@ -1262,7 +1262,7 @@ def convert_np_dtype_to_dtype_(np_dtype):
         core.VarDesc.VarType / core.DataType : The data type in Paddle.
 
     """
-    if in_pir_mode():
+    if use_pir_api():
         return pir.core.convert_np_dtype_to_dtype_(np_dtype)
 
     # Convert the data type string to numpy data type.
@@ -1350,11 +1350,15 @@ def _create_tensor(
     **kwargs,
 ):
     if dtype is not None:
-        if not isinstance(dtype, core.VarDesc.VarType):
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
             dtype = convert_np_dtype_to_dtype_(dtype)
+        if isinstance(dtype, core.DataType):
+            dtype = paddle_type_to_proto_type[dtype]
+    else:
+        dtype = core.VarDesc.VarType.FP32
 
     eager_tensor = core.eager.Tensor(
-        dtype if dtype else core.VarDesc.VarType.FP32,
+        dtype,
         list(shape) if shape else [],
         name,
         type if type else core.VarDesc.VarType.LOD_TENSOR,
@@ -7588,8 +7592,12 @@ def __init__(self, shape, dtype, **kwargs):
             )
 
         if dtype is not None:
-            if not isinstance(dtype, core.VarDesc.VarType):
+            if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
                 dtype = convert_np_dtype_to_dtype_(dtype)
+            if isinstance(dtype, core.DataType):
+                dtype = paddle_type_to_proto_type[dtype]
+        else:
+            dtype = core.VarDesc.VarType.FP32
 
         name = kwargs.get('name', unique_name.generate('_eager_param_base'))
 
@@ -7597,7 +7605,7 @@ def __init__(self, shape, dtype, **kwargs):
             shape = shape.numpy()
 
         super().__init__(
-            dtype if dtype else core.VarDesc.VarType.FP32,
+            dtype,
             list(shape) if shape else [],
             name,
             core.VarDesc.VarType.LOD_TENSOR,
diff --git a/python/paddle/framework/dtype.py b/python/paddle/framework/dtype.py
index 1183d80d035308..aeb93681730dfa 100644
--- a/python/paddle/framework/dtype.py
+++ b/python/paddle/framework/dtype.py
@@ -12,32 +12,130 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
+
+from ..base import framework
 from ..base.core import (
+    DataType,
     VarDesc,
     finfo as core_finfo,
     iinfo as core_iinfo,
 )
 from ..base.data_feeder import _NUMPY_DTYPE_2_PADDLE_DTYPE
 
-dtype = VarDesc.VarType
-dtype.__qualname__ = "dtype"
-dtype.__module__ = "paddle"
-
-uint8 = VarDesc.VarType.UINT8
-int8 = VarDesc.VarType.INT8
-int16 = VarDesc.VarType.INT16
-int32 = VarDesc.VarType.INT32
-int64 = VarDesc.VarType.INT64
-
-float32 = VarDesc.VarType.FP32
-float64 = VarDesc.VarType.FP64
-float16 = VarDesc.VarType.FP16
-bfloat16 = VarDesc.VarType.BF16
-
-complex64 = VarDesc.VarType.COMPLEX64
-complex128 = VarDesc.VarType.COMPLEX128
-bool = VarDesc.VarType.BOOL
+def bind_vartype():
+    global dtype
+    global uint8
+    global int8
+    global int16
+    global int32
+    global int64
+    global float32
+    global float64
+    global float16
+    global bfloat16
+    global complex64
+    global complex128
+    global bool
+
+    dtype = VarDesc.VarType
+    dtype.__qualname__ = "dtype"
+    dtype.__module__ = "paddle"
+
+    uint8 = VarDesc.VarType.UINT8
+    int8 = VarDesc.VarType.INT8
+    int16 = VarDesc.VarType.INT16
+    int32 = VarDesc.VarType.INT32
+    int64 = VarDesc.VarType.INT64
+
+    float32 = VarDesc.VarType.FP32
+    float64 = VarDesc.VarType.FP64
+    float16 = VarDesc.VarType.FP16
+    bfloat16 = VarDesc.VarType.BF16
+
+    complex64 = VarDesc.VarType.COMPLEX64
+    complex128 = VarDesc.VarType.COMPLEX128
+
+    bool = VarDesc.VarType.BOOL
+
+    paddle.dtype = dtype
+    paddle.uint8 = uint8
+    paddle.int8 = int8
+    paddle.int16 = int16
+    paddle.int32 = int32
+    paddle.int64 = int64
+
+    paddle.float32 = float32
+    paddle.float64 = float64
+    paddle.float16 = float16
+    paddle.bfloat16 = bfloat16
+
+    paddle.complex64 = complex64
+    paddle.complex128 = complex128
+    paddle.bool = bool
+
+
+def bind_datatype():
+    global dtype
+    global uint8
+    global int8
+    global int16
+    global int32
+    global int64
+    global float32
+    global float64
+    global float16
+    global bfloat16
+    global complex64
+    global complex128
+    global bool
+
+    dtype = DataType
+    dtype.__qualname__ = "dtype"
+    dtype.__module__ = "paddle"
+
+    uint8 = DataType.UINT8
+    int8 = DataType.INT8
+    int16 = DataType.INT16
+    int32 = DataType.INT32
+    int64 = DataType.INT64
+
+    float32 = DataType.FLOAT32
+    float64 = DataType.FLOAT64
+    float16 = DataType.FLOAT16
+    bfloat16 = DataType.BFLOAT16
+
+    complex64 = DataType.COMPLEX64
+    complex128 = DataType.COMPLEX128
+
+    bool = DataType.BOOL
+
+    paddle.dtype = dtype
+    paddle.uint8 = uint8
+    paddle.int8 = int8
+    paddle.int16 = int16
+    paddle.int32 = int32
+    paddle.int64 = int64
+
+    paddle.float32 = float32
+    paddle.float64 = float64
+    paddle.float16 = float16
+    paddle.bfloat16 = bfloat16
+
+    paddle.complex64 = complex64
+    paddle.complex128 = complex128
+    paddle.bool = bool
+
+
+enable_pir_api = framework.get_flags("FLAGS_enable_pir_api")[
+    "FLAGS_enable_pir_api"
+]
+
+if enable_pir_api:
+    bind_datatype()
+else:
+    bind_vartype()
 
 
 def iinfo(dtype):
@@ -130,9 +228,7 @@ def finfo(dtype):
 
     """
     import paddle
 
-    if paddle.base.framework.in_pir_mode() and isinstance(
-        dtype, paddle.pir.core.DataType
-    ):
+    if isinstance(dtype, paddle.pir.core.DataType):
         dtype = paddle.base.framework.paddle_type_to_proto_type[dtype]
     elif dtype in _NUMPY_DTYPE_2_PADDLE_DTYPE:
         dtype = _NUMPY_DTYPE_2_PADDLE_DTYPE[dtype]
diff --git a/python/paddle/jit/pir_dy2static/parameter_recorder.py b/python/paddle/jit/pir_dy2static/parameter_recorder.py
index ef0440eaa981b7..646e810ffe3e47 100644
--- a/python/paddle/jit/pir_dy2static/parameter_recorder.py
+++ b/python/paddle/jit/pir_dy2static/parameter_recorder.py
@@ -14,6 +14,7 @@
 
 import paddle
 from paddle.autograd.backward_utils import ValueDict
+from paddle.framework import core
 
 from ..dy2static.program_translator import _program_hash, synchronized
 
@@ -37,8 +38,11 @@ def get(self, program, tensor):
         mappings = self.tensor2value[key]
         if id(tensor) not in mappings:
             non_used_initializer = paddle.nn.initializer.Constant(0.0)
+            dtype = tensor.dtype
+            if isinstance(dtype, core.VarDesc.VarType):
+                dtype = vartype_to_datatype[dtype]
             value = create_parameter(
-                dtype=vartype_to_datatype[tensor.dtype],
+                dtype=dtype,
                 shape=tensor.shape,
                 type=tensor.type,
                 initializer=non_used_initializer,
diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py
index 7f90468bdf4b09..7eebf39e008916 100644
--- a/python/paddle/jit/sot/infer_meta.py
+++ b/python/paddle/jit/sot/infer_meta.py
@@ -16,7 +16,6 @@
 
 import paddle
 from paddle.amp.auto_cast import amp_state
-from paddle.base import framework
 from paddle.base.data_feeder import convert_dtype
 from paddle.base.unique_name import (
     UniqueNameGenerator,
@@ -41,15 +40,20 @@ def __init__(
 
     @staticmethod
     def from_tensor(tensor):
-        # We always use float32 in simulation if AMP is enabled.
         if isinstance(tensor, paddle.pir.Value):
             name = "Value@NoName"
-            persistable = tensor.persistable
-            dtype = framework.paddle_type_to_proto_type[tensor.dtype]
-        else:
+        else:  # For Tensor or Variable
             name = tensor.name
-            persistable = tensor.persistable
-            dtype = tensor.dtype
+        persistable = tensor.persistable
+        dtype = tensor.dtype
+        expected_dtype_class = (
+            paddle.core.DataType
+            if paddle.framework.use_pir_api()
+            else paddle.core.VarDesc.VarType
+        )
+        assert isinstance(dtype, expected_dtype_class)
+
+        # We always use float32 in simulation if AMP is enabled.
         current_amp_state = amp_state()
         if (
             dtype == paddle.float16
diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
index 7c3490aed0eb86..fe99525fe44a14 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py
@@ -22,8 +22,7 @@
 import numpy as np
 
 import paddle
-from paddle.framework import use_pir_api
-from paddle.pir.core import vartype_to_datatype
+from paddle.framework import core
 
 from ....infer_meta import MetaInfo
 from ....symbolic.statement_ir import Symbol
@@ -61,30 +60,30 @@
 
 FP_DTYPE_ABBRS = {
-    paddle.bfloat16: 'bfloat16',
-    paddle.float64: 'float64',
-    paddle.float32: 'float32',
-    paddle.float16: 'float16',
+    core.DataType.BFLOAT16: "bfloat16",
+    core.DataType.FLOAT64: "float64",
+    core.DataType.FLOAT32: "float32",
+    core.DataType.FLOAT16: "float16",
 }
 
 CP_DTYPE_ABBRS = {
-    paddle.complex64: 'complex64',
-    paddle.complex128: 'complex128',
+    core.DataType.COMPLEX64: "complex64",
+    core.DataType.COMPLEX128: "complex128",
 }
 
 INT_DTYPE_ABBRS = {
-    paddle.int8: 'int8',
-    paddle.int16: 'int16',
-    paddle.int32: 'int32',
-    paddle.int64: 'int64',
-    paddle.uint8: 'uint8',
+    core.DataType.INT8: "int8",
+    core.DataType.INT16: "int16",
+    core.DataType.INT32: "int32",
+    core.DataType.INT64: "int64",
+    core.DataType.UINT8: "uint8",
 }
 
 DTYPE_ABBRS = {
     **FP_DTYPE_ABBRS,
     **CP_DTYPE_ABBRS,
     **INT_DTYPE_ABBRS,
-    paddle.bool: 'bool',
+    core.DataType.BOOL: "bool",
 }
 
@@ -271,32 +270,14 @@ def make_stringify_guard(self) -> list[StringifyExpression]:
         return object_equal_stringify_guard(self)
 
     def get_py_value(self, allow_tensor=False):
-        if use_pir_api() and isinstance(
-            self.value, paddle.base.core.VarDesc.VarType
-        ):
-            return vartype_to_datatype[self.value]
         return super().get_py_value(allow_tensor)
 
     def get_py_type(self):
-        if use_pir_api() and isinstance(
-            self.value, paddle.base.core.VarDesc.VarType
-        ):
-            return paddle.pir.core.DataType
         return super().get_py_type()
 
     def _reconstruct(self, codegen: PyCodeGen):
         # dtype of paddle.Tensor is hashable, we can just load it as const var
-        if use_pir_api() and isinstance(
-            self.value, paddle.base.core.VarDesc.VarType
-        ):
-            assert (
-                self.value in paddle.pir.core.vartype_to_datatype
-            ), f"Unknow dtype {self.value}"
-            codegen.gen_load_const(
-                paddle.pir.core.vartype_to_datatype[self.value]
-            )
-        else:
-            codegen.gen_load_const(self.value)
+        codegen.gen_load_const(self.value)
 
     @property
     def main_info(self) -> dict[str, Any]:
@@ -306,7 +287,9 @@ def main_info(self) -> dict[str, Any]:
 
     @VariableFactory.register_from_value()
     def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
-        if isinstance(value, paddle.dtype):
+        if isinstance(
+            value, (paddle.core.VarDesc.VarType, paddle.core.DataType)
+        ):
             return TensorDtypeVariable(value, graph, tracker)
 
@@ -410,15 +393,18 @@ def get_iter(self):
 
     @property
     def main_info(self) -> dict[str, Any]:
+        dtype = self.meta.dtype
+        if isinstance(dtype, paddle.core.VarDesc.VarType):
+            dtype = paddle.pir.core.vartype_to_datatype[dtype]
         return {
             "shape": self.meta.shape,
-            "dtype": DTYPE_ABBRS[self.meta.dtype],
+            "dtype": DTYPE_ABBRS[dtype],
             "stop_gradient": self.meta.stop_gradient,
             "var_name": self.var_name,
         }
 
     def getitem(self, key):
-        return self.graph.call_tensor_method('__getitem__', self, key)
+        return self.graph.call_tensor_method("__getitem__", self, key)
 
     def setitem(self, key, value):
         self.graph.add_global_guarded_variable(value)
@@ -502,16 +488,22 @@ def is_tensor(self):
 
     def is_complex(self):
         dtype = self.meta.dtype
+        if isinstance(dtype, paddle.core.VarDesc.VarType):
+            dtype = paddle.pir.core.vartype_to_datatype[dtype]
         is_cp_dtype = dtype in CP_DTYPE_ABBRS
         return ConstantVariable(is_cp_dtype, self.graph, DummyTracker([self]))
 
     def is_integer(self):
         dtype = self.meta.dtype
+        if isinstance(dtype, paddle.core.VarDesc.VarType):
+            dtype = paddle.pir.core.vartype_to_datatype[dtype]
         is_int_dtype = dtype in INT_DTYPE_ABBRS
         return ConstantVariable(is_int_dtype, self.graph, DummyTracker([self]))
 
     def is_floating_point(self):
         dtype = self.meta.dtype
+        if isinstance(dtype, paddle.core.VarDesc.VarType):
+            dtype = paddle.pir.core.vartype_to_datatype[dtype]
         is_fp_dtype = dtype in FP_DTYPE_ABBRS
         return ConstantVariable(is_fp_dtype, self.graph, DummyTracker([self]))
 
diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py
index 8e3282e766fffa..0f551b1aa6c416 100644
--- a/python/paddle/nn/clip.py
+++ b/python/paddle/nn/clip.py
@@ -708,11 +708,11 @@ def _dygraph_clip(self, params_grads):
                 )
 
             if (
-                sum_square.dtype == core.VarDesc.VarType.FP16
-                or sum_square.dtype == core.VarDesc.VarType.BF16
+                sum_square.dtype == paddle.float16
+                or sum_square.dtype == paddle.bfloat16
             ):
                 sum_square_list_fp16.append(sum_square)
-            elif sum_square.dtype == core.VarDesc.VarType.FP32:
+            elif sum_square.dtype == paddle.float32:
                 sum_square_list_fp32.append(sum_square)
             else:
                 sum_square_list.append(sum_square)
diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index f8b3dfb7b515ea..de848b9e16cced 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -1555,6 +1555,11 @@ def flatten_parameters(self):
             )
             if in_dynamic_mode():
                 with paddle.no_grad():
+                    dtype = params[0].dtype
+                    if isinstance(dtype, core.DataType):
+                        dtype = paddle.base.framework.paddle_type_to_proto_type[
+                            dtype
+                        ]
                     _legacy_C_ops.coalesce_tensor(
                         self._all_weights,
                         self._all_weights,
@@ -1564,7 +1569,7 @@ def flatten_parameters(self):
                         "use_align",
                         False,
                         "dtype",
-                        params[0].dtype,
+                        dtype,
                     )
                     return
             # for static-graph, append coalesce_tensor into startup program
diff --git a/python/paddle/pir_utils.py b/python/paddle/pir_utils.py
index e52837889d71f0..9adf1d04710894 100644
--- a/python/paddle/pir_utils.py
+++ b/python/paddle/pir_utils.py
@@ -16,6 +16,7 @@
 from functools import wraps
 
 import paddle
+from paddle.framework.dtype import bind_datatype, bind_vartype
 
 
 class IrGuard:
@@ -49,11 +50,13 @@ def __enter__(self):
             paddle.enable_static()
         paddle.framework.set_flags({"FLAGS_enable_pir_api": True})
         paddle.base.framework.global_var._use_pir_api_ = True
+        bind_datatype()
         self._switch_to_pir()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         paddle.framework.set_flags({"FLAGS_enable_pir_api": False})
         paddle.base.framework.global_var._use_pir_api_ = False
+        bind_vartype()
         self._switch_to_old_ir()
         if self.in_dygraph_outside:
             paddle.disable_static()
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index d2a0c46369fadf..496ec9965d0cfc 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -23,6 +23,7 @@
     in_dynamic_mode,
     in_dynamic_or_pir_mode,
     in_pir_mode,
+    use_pir_api,
 )
 
 from ..base.data_feeder import (
@@ -1100,9 +1101,9 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
         low = 0
     if dtype is None:
         dtype = core.VarDesc.VarType.INT64
-        if in_pir_mode():
+        if use_pir_api():
             dtype = DataType.INT64
-    elif not isinstance(dtype, core.VarDesc.VarType):
+    elif not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
         dtype = convert_np_dtype_to_dtype_(dtype)
 
     if in_dynamic_mode():
diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py
index 1ee370b1745bfd..2523efee2b395a 100644
--- a/test/dygraph_to_static/test_declarative.py
+++ b/test/dygraph_to_static/test_declarative.py
@@ -249,7 +249,7 @@ def test_with_different_input(self):
 
             foo = paddle.jit.to_static(foo_func)
 
-            # [16, 10] + [10] (varbase)
+            # [16, 10] + [10] (Tensor)
             out_1 = foo(paddle.to_tensor(x_data), paddle.to_tensor(y_data))
             np.testing.assert_allclose(x_data + y_data, out_1.numpy(), rtol=1e-05)
             self.assertTrue(len(foo.program_cache) == 1)
diff --git a/test/legacy_test/test_std_layer.py b/test/legacy_test/test_std_layer.py
index aed3e750402e5c..9c42e7aae3829b 100644
--- a/test/legacy_test/test_std_layer.py
+++ b/test/legacy_test/test_std_layer.py
@@ -116,6 +116,7 @@ def test_alias(self):
 
 class TestStdError(unittest.TestCase):
     def test_error(self):
+        paddle.enable_static()
         with paddle.static.program_guard(paddle.static.Program()):
             x = paddle.static.data('X', [2, 3, 4], 'int32')
             self.assertRaises(TypeError, paddle.std, x)
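
Not part of the diff above: a minimal sketch of the user-visible behavior these changes target, assuming a build where FLAGS_enable_pir_api is on, so bind_datatype() has rebound the module-level dtypes to core.DataType (with the flag off, bind_vartype() keeps them as core.VarDesc.VarType, matching the old behavior).

import paddle

# Sketch only: exact reprs depend on the build and flag; the point is that
# Tensor.dtype and the module-level dtypes stay comparable under either binding.
x = paddle.randn([2, 3])
print(x.dtype)                    # paddle.float32 (VarType or DataType, per flag)
print(x.dtype == paddle.float32)  # True either way

# astype() accepts VarType, DataType, or a numpy/string dtype after this change.
y = x.astype('float64')
print(y.dtype == paddle.float64)  # True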