diff --git a/rfcs/APIs/20220330_api_design_for_finfo.md b/rfcs/APIs/20220330_api_design_for_finfo.md index 5c78bbcad..da3606982 100644 --- a/rfcs/APIs/20220330_api_design_for_finfo.md +++ b/rfcs/APIs/20220330_api_design_for_finfo.md @@ -2,11 +2,11 @@ | API名称 | paddle.finfo | | ------------------------------------------------------------ | -------------------------------- | -| 提交作者 | 林旭(isLinXu) | -| 提交时间 | 2022-04-12 | -| 版本号 | V2.0 | +| 提交作者 | lisamhy,林旭(isLinXu) | +| 提交时间 | 2022-04-12 | +| 版本号 | V2.0 | | 依赖飞桨版本 | develop | -| 文件名 | 20220330_api-design_for_finfo.md | +| 文件名 | 20220330_api-design_for_finfo.md | # 一、概述 @@ -484,34 +484,112 @@ API设计为`paddle.finfo(dtype)`,根据选择计算方法(比如eps、max、m 通过设计实现与API对应的Class,并通过pybind将相应的成员函数绑定到python,从而实现该API。 -- `.h`头文件定义声明 +- `pybind.cc` finfo class 实现 ```cpp -namespace paddle { -namespace pybind { - void BindFinfoVarDsec(pybind11::module *m); - void BindIinfoVarDsec(pybind11::module *m); -} -} +struct finfo { + int64_t bits; + double eps; + double min; // lowest() + double max; + double tiny; + double smallest_normal; // min() + double resolution; + std::string dtype; + + explicit finfo(const framework::proto::VarType::Type &type) { + switch (type) { + case framework::proto::VarType::FP16: + eps = std::numeric_limits::epsilon(); + min = std::numeric_limits::lowest(); + max = std::numeric_limits::max(); + smallest_normal = std::numeric_limits::min(); + tiny = smallest_normal; + resolution = std::pow( + 10, -std::numeric_limits::digits10); + bits = 16; + dtype = "float16"; + break; + case framework::proto::VarType::FP32: + case framework::proto::VarType::COMPLEX64: + eps = std::numeric_limits::epsilon(); + min = std::numeric_limits::lowest(); + max = std::numeric_limits::max(); + smallest_normal = std::numeric_limits::min(); + tiny = smallest_normal; + resolution = std::pow(10, -std::numeric_limits::digits10); + bits = 32; + dtype = "float32"; + break; + case framework::proto::VarType::FP64: + case framework::proto::VarType::COMPLEX128: + eps = std::numeric_limits::epsilon(); + min = std::numeric_limits::lowest(); + max = std::numeric_limits::max(); + smallest_normal = std::numeric_limits::min(); + tiny = smallest_normal; + resolution = std::pow(10, -std::numeric_limits::digits10); + bits = 64; + dtype = "float64"; + break; + case framework::proto::VarType::BF16: + eps = std::numeric_limits::epsilon(); + min = std::numeric_limits::lowest(); + max = std::numeric_limits::max(); + smallest_normal = + std::numeric_limits::min(); + tiny = smallest_normal; + resolution = std::pow( + 10, -std::numeric_limits::digits10); + bits = 16; + dtype = "bfloat16"; + break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "the argument of paddle.finfo can only be paddle.float32, " + "paddle.float64, paddle.float16, paddle.bfloat16" + "paddle.complex64, or paddle.complex128")); + break; + } + } +}; ``` -- `.cc`绑定实现设计 +- `pybind.cc` finfo 绑定实现 ```cpp - -void BindFInfoVarDsec(pybind11::module *m){ - pybind11::class_ finfo_var_desc(*m, "VarDesc", ""); - finfo_var_desc.def(pybind11::init()) - .def("bits", &pd::Tinfo::Bits) - .def("eps", &pd::Tinfo::Eps) - .def("min", &pd::Tinfo::Min) - .def("max", &pd::Tinfo::Max) - .def("tiny", &pd::Tinfo::Tiny) - .def("resolution", &pd::Tinfo::Resolution) -} + py::class_(m, "finfo") + .def(py::init()) + .def_readonly("min", &finfo::min) + .def_readonly("max", &finfo::max) + .def_readonly("bits", &finfo::bits) + .def_readonly("eps", &finfo::eps) + .def_readonly("resolution", &finfo::resolution) + .def_readonly("smallest_normal", &finfo::smallest_normal) + .def_readonly("tiny", &finfo::tiny) + .def_readonly("dtype", &finfo::dtype) + .def("__repr__", [](const finfo &a) { + std::ostringstream oss; + oss << "paddle.finfo(min=" << a.min; + oss << ", max=" << a.max; + oss << ", eps=" << a.eps; + oss << ", resolution=" << a.resolution; + oss << ", smallest_normal=" << a.smallest_normal; + oss << ", tiny=" << a.tiny; + oss << ", bits=" << a.bits; + oss << ", dtype=" << a.dtype << ")"; + return oss.str(); + }); ``` +- `dtype.py` python 暴露 finfo API +```python +from ..fluid.core import finfo as core_finfo + +def finfo(dtype): + return core_finfo(dtype) +``` 实现思路: @@ -519,56 +597,7 @@ void BindFInfoVarDsec(pybind11::module *m){ - 因此要实现该API,需要如上抽象出一个符合要求的Class,同时并声明定义类下的成员函数来分别实现功能 -- 通过类的成员函数分别来实现eps、min、max等函数,通过Pybind11来进行接口与参数的绑定 - - - -## API实现方案 - -在paddle/fluid/framework/Info.h与Info.cc下新增实现函数 -定义class为`Tinfo`(借鉴Torch的结构设计,将finfo与iinfo合并为一个类进行实现) - -```c -class Tinfo { -public: - int Bits(const at::ScalarType& type) - float Eps(const at::ScalarType& type) - float Min(const at::ScalarType& type) - float Max(const at::ScalarType& type) - float Tiny(const at::ScalarType& type) - float Resolution(const at::ScalarType& type) -} -``` - -`.cc`实现 - -```cpp -int Tinfo::Bits(const at::ScalarType& type){ - int bits = elementSize(self->type) * 8; - return THPUtils_packInt64(bits); -} - -float Tinfo::Eps(const at::ScalarType& type){ - return std::numeric_limits::type>::epsilon()); -} - -float Tinfo::Min(const at::ScalarType& type){ - return std::numeric_limits::type>::lowest()); -} - -float Tinfo::Max(const at::ScalarType& type){ - return std::numeric_limits::type>::max()); -} - -float Tinfo::Tiny(const at::ScalarType& type){ - return std::numeric_limits::type>::min()); -} - -float Tinfo::Resolution(const at::ScalarType& type){ - return std::numeric_limits::type>::resolution()); - -} -``` +- 通过类的成员函数分别来实现 eps、min、max、bits、resolution、tiny、smallest_normal、dtype 等函数,通过Pybind11来进行接口与参数的绑定