Fix bug poolop #54766

Merged 1 commit into PaddlePaddle:develop on Jun 27, 2023
Conversation

@heavyrain-lzy (Contributor) commented Jun 19, 2023

PR types

Bug fixes

PR changes

OPs

Description

Pcard-67001
Fixes the precision issue introduced by #54070. Moving use_cudnn into the extra attributes caused the parameter to be ignored when saving the model. Compared with the same kind of parameter on other operators, this operator differs in that the Python API defaults use_cudnn to true while the C++ side defaults it to false, which breaks dynamic-to-static conversion.
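To make the failure mode concrete, here is a minimal standalone sketch (an illustration for this description, not code from the PR) of why a dropped extra attribute whose Python and C++ defaults disagree changes behavior:

# The dygraph Python API implicitly runs AvgPool2D with use_cudnn=True,
# but "extra" attributes are omitted from the saved program, so the loaded
# program falls back to the C++ default use_cudnn=False and a different
# pooling kernel is selected.
PYTHON_DEFAULT = True   # default assumed by the Python-side API
CPP_DEFAULT = False     # default declared by the C++ pool2d operator

saved_attrs = {}        # extra attrs are not serialized with the model
loaded_use_cudnn = saved_attrs.get("use_cudnn", CPP_DEFAULT)

# The dygraph run and the loaded program now pick different kernels:
assert loaded_use_cudnn != PYTHON_DEFAULT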
Test code:

import numpy as np
import paddle

paddle.seed(33)
np.random.seed(33)


def randtool(dtype, low, high, shape):
    """
    np random tools
    """
    if dtype == "int":
        return np.random.randint(low, high, shape)

    elif dtype == "float":
        return low + (high - low) * np.random.random(shape)


class BuildClass(paddle.nn.Layer):
    """
    nn.Layer used for dynamic-to-static conversion
    """

    def __init__(self, in_params, func):
        super(BuildClass, self).__init__()
        self.func = eval(func)(**in_params)

    def forward(self, input):
        """
        forward
        """
        x = self.func(input)
        return x


# class BuildJitFunc(paddle.nn.Layer):
#     """
#     nn.Layer used for dynamic-to-static conversion
#     """
#
#     def __init__(self, in_params, func):
#         super(BuildJitFunc, self).__init__()
#         paddle.seed(33)
#         self.func = eval(func)
#         self._params = in_params
#
#     @paddle.jit.to_static
#     def forward(self, inputs):
#         """
#         forward
#         """
#         x = self.func(inputs, **self._params)
#         return x


func = "paddle.nn.AvgPool2D"

in_tensor = {"x": paddle.to_tensor(randtool("float", -10, 10, shape=[2, 3, 4, 4]), dtype="float32")}

in_params = {
      "kernel_size": [3, 3],
      "stride": [3, 3],
      "padding": [0, 0, 0, 0],
      "ceil_mode": True,
      "exclusive": False,
}

obj = BuildClass(in_params, func)

obj.eval()

paddle.seed(33)
dy_out = obj(in_tensor["x"])  # dynamic-graph (dygraph) output
# print("dy_out is: ", dy_out)

jit_obj = paddle.jit.to_static(obj)  # dynamic-to-static conversion
print("jit_obj is created successfully !!!")

paddle.seed(33)
st_out = jit_obj(in_tensor["x"])
# print("st_out is: ", st_out)

paddle.seed(33)
paddle.jit.save(jit_obj, path="avgpool2d")  # extra attrs are dropped at save time

paddle.seed(33)
jit = paddle.jit.load("avgpool2d")  # reload the saved static program

paddle.seed(33)
res = jit(in_tensor["x"])
# print('jit.load res: ', res)

np.testing.assert_allclose(actual=res.numpy(), desired=dy_out.numpy(), atol=1e-5, rtol=1e-6, equal_nan=True)
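As an optional follow-up check (a sketch, assuming a Paddle version where TranslatedLayer exposes program() and static operators expose has_attr/attr), one can inspect the loaded program to see whether the pool2d op still carries use_cudnn:

# Inspect the ops of the loaded program (illustrative; not part of the PR):
prog = jit.program()
for block in prog.blocks:
    for op in block.ops:
        if op.type == "pool2d":
            # Before this fix, use_cudnn is missing from the saved program
            # (or falls back to the C++ default, False).
            has = op.has_attr("use_cudnn")
            print(op.type, op.attr("use_cudnn") if has else "use_cudnn missing")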

@paddle-bot (bot) commented Jun 19, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

YuanRisheng previously approved these changes Jun 20, 2023

@zyfncg (Contributor) left a comment

Couldn't we just remove use_cudnn from op_compat directly? Changing it this way adds quite a bit of complexity.

@heavyrain-lzy (Contributor, Author)

> Couldn't we just remove use_cudnn from op_compat directly? Changing it this way adds quite a bit of complexity.

But the original use_cudnn is marked AsExtra() in the Maker; wouldn't removing it directly cause compatibility problems?

@zyfncg (Contributor) commented Jun 20, 2023

> > Couldn't we just remove use_cudnn from op_compat directly? Changing it this way adds quite a bit of complexity.
>
> But the original use_cudnn is marked AsExtra() in the Maker; wouldn't removing it directly cause compatibility problems?

AsExtra in the Maker no longer has any effect; removing it will not cause any problems.

@heavyrain-lzy (Contributor, Author)

> > > Couldn't we just remove use_cudnn from op_compat directly? Changing it this way adds quite a bit of complexity.
> >
> > But the original use_cudnn is marked AsExtra() in the Maker; wouldn't removing it directly cause compatibility problems?
>
> AsExtra in the Maker no longer has any effect; removing it will not cause any problems.

OK, I'll revert it then.

@heavyrain-lzy heavyrain-lzy merged commit 689e27a into PaddlePaddle:develop Jun 27, 2023
heavyrain-lzy added a commit to heavyrain-lzy/Paddle that referenced this pull request Jul 5, 2023