diff --git a/rfcs/APIs/20220330_api_design_for_finfo.md b/rfcs/APIs/20220330_api_design_for_finfo.md
index 5c78bbcad..da3606982 100644
--- a/rfcs/APIs/20220330_api_design_for_finfo.md
+++ b/rfcs/APIs/20220330_api_design_for_finfo.md
@@ -2,11 +2,11 @@
| API名称 | paddle.finfo |
| ------------------------------------------------------------ | -------------------------------- |
-| 提交作者 | 林旭(isLinXu) |
-| 提交时间 | 2022-04-12 |
-| 版本号 | V2.0 |
+| 提交作者 | lisamhy,林旭(isLinXu) |
+| 提交时间 | 2022-04-12 |
+| 版本号 | V2.0 |
| 依赖飞桨版本 | develop |
-| 文件名 | 20220330_api-design_for_finfo.md |
+| 文件名 | 20220330_api-design_for_finfo.md |
# 一、概述
@@ -484,34 +484,112 @@ API设计为`paddle.finfo(dtype)`,根据选择计算方法(比如eps、max、m
通过设计实现与API对应的Class,并通过pybind将相应的成员函数绑定到python,从而实现该API。
-- `.h`头文件定义声明
+- `pybind.cc` finfo class 实现
```cpp
-namespace paddle {
-namespace pybind {
- void BindFinfoVarDsec(pybind11::module *m);
- void BindIinfoVarDsec(pybind11::module *m);
-}
-}
+struct finfo {
+ int64_t bits;
+ double eps;
+ double min; // lowest()
+ double max;
+ double tiny;
+ double smallest_normal; // min()
+ double resolution;
+ std::string dtype;
+
+ explicit finfo(const framework::proto::VarType::Type &type) {
+ switch (type) {
+ case framework::proto::VarType::FP16:
+ eps = std::numeric_limits::epsilon();
+ min = std::numeric_limits::lowest();
+ max = std::numeric_limits::max();
+ smallest_normal = std::numeric_limits::min();
+ tiny = smallest_normal;
+ resolution = std::pow(
+ 10, -std::numeric_limits::digits10);
+ bits = 16;
+ dtype = "float16";
+ break;
+ case framework::proto::VarType::FP32:
+ case framework::proto::VarType::COMPLEX64:
+ eps = std::numeric_limits::epsilon();
+ min = std::numeric_limits::lowest();
+ max = std::numeric_limits::max();
+ smallest_normal = std::numeric_limits::min();
+ tiny = smallest_normal;
+ resolution = std::pow(10, -std::numeric_limits::digits10);
+ bits = 32;
+ dtype = "float32";
+ break;
+ case framework::proto::VarType::FP64:
+ case framework::proto::VarType::COMPLEX128:
+ eps = std::numeric_limits::epsilon();
+ min = std::numeric_limits::lowest();
+ max = std::numeric_limits::max();
+ smallest_normal = std::numeric_limits::min();
+ tiny = smallest_normal;
+ resolution = std::pow(10, -std::numeric_limits::digits10);
+ bits = 64;
+ dtype = "float64";
+ break;
+ case framework::proto::VarType::BF16:
+ eps = std::numeric_limits::epsilon();
+ min = std::numeric_limits::lowest();
+ max = std::numeric_limits::max();
+ smallest_normal =
+ std::numeric_limits::min();
+ tiny = smallest_normal;
+ resolution = std::pow(
+ 10, -std::numeric_limits::digits10);
+ bits = 16;
+ dtype = "bfloat16";
+ break;
+ default:
+ PADDLE_THROW(platform::errors::InvalidArgument(
+ "the argument of paddle.finfo can only be paddle.float32, "
+ "paddle.float64, paddle.float16, paddle.bfloat16"
+ "paddle.complex64, or paddle.complex128"));
+ break;
+ }
+ }
+};
```
-- `.cc`绑定实现设计
+- `pybind.cc` finfo 绑定实现
```cpp
-
-void BindFInfoVarDsec(pybind11::module *m){
- pybind11::class_ finfo_var_desc(*m, "VarDesc", "");
- finfo_var_desc.def(pybind11::init())
- .def("bits", &pd::Tinfo::Bits)
- .def("eps", &pd::Tinfo::Eps)
- .def("min", &pd::Tinfo::Min)
- .def("max", &pd::Tinfo::Max)
- .def("tiny", &pd::Tinfo::Tiny)
- .def("resolution", &pd::Tinfo::Resolution)
-}
+ py::class_(m, "finfo")
+ .def(py::init())
+ .def_readonly("min", &finfo::min)
+ .def_readonly("max", &finfo::max)
+ .def_readonly("bits", &finfo::bits)
+ .def_readonly("eps", &finfo::eps)
+ .def_readonly("resolution", &finfo::resolution)
+ .def_readonly("smallest_normal", &finfo::smallest_normal)
+ .def_readonly("tiny", &finfo::tiny)
+ .def_readonly("dtype", &finfo::dtype)
+ .def("__repr__", [](const finfo &a) {
+ std::ostringstream oss;
+ oss << "paddle.finfo(min=" << a.min;
+ oss << ", max=" << a.max;
+ oss << ", eps=" << a.eps;
+ oss << ", resolution=" << a.resolution;
+ oss << ", smallest_normal=" << a.smallest_normal;
+ oss << ", tiny=" << a.tiny;
+ oss << ", bits=" << a.bits;
+ oss << ", dtype=" << a.dtype << ")";
+ return oss.str();
+ });
```
+- `dtype.py` python 暴露 finfo API
+```python
+from ..fluid.core import finfo as core_finfo
+
+def finfo(dtype):
+ return core_finfo(dtype)
+```
实现思路:
@@ -519,56 +597,7 @@ void BindFInfoVarDsec(pybind11::module *m){
- 因此要实现该API,需要如上抽象出一个符合要求的Class,同时并声明定义类下的成员函数来分别实现功能
-- 通过类的成员函数分别来实现eps、min、max等函数,通过Pybind11来进行接口与参数的绑定
-
-
-
-## API实现方案
-
-在paddle/fluid/framework/Info.h与Info.cc下新增实现函数
-定义class为`Tinfo`(借鉴Torch的结构设计,将finfo与iinfo合并为一个类进行实现)
-
-```c
-class Tinfo {
-public:
- int Bits(const at::ScalarType& type)
- float Eps(const at::ScalarType& type)
- float Min(const at::ScalarType& type)
- float Max(const at::ScalarType& type)
- float Tiny(const at::ScalarType& type)
- float Resolution(const at::ScalarType& type)
-}
-```
-
-`.cc`实现
-
-```cpp
-int Tinfo::Bits(const at::ScalarType& type){
- int bits = elementSize(self->type) * 8;
- return THPUtils_packInt64(bits);
-}
-
-float Tinfo::Eps(const at::ScalarType& type){
- return std::numeric_limits::type>::epsilon());
-}
-
-float Tinfo::Min(const at::ScalarType& type){
- return std::numeric_limits::type>::lowest());
-}
-
-float Tinfo::Max(const at::ScalarType& type){
- return std::numeric_limits::type>::max());
-}
-
-float Tinfo::Tiny(const at::ScalarType& type){
- return std::numeric_limits::type>::min());
-}
-
-float Tinfo::Resolution(const at::ScalarType& type){
- return std::numeric_limits::type>::resolution());
-
-}
-```
+- 通过类的成员函数分别来实现 eps、min、max、bits、resolution、tiny、smallest_normal、dtype 等函数,通过Pybind11来进行接口与参数的绑定