-
Notifications
You must be signed in to change notification settings - Fork 115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PaddlePaddle hackathon] paddle.nn.PixelShuffle单测 #226
Conversation
self.types = [np.float32] | ||
|
||
|
||
obj = TestClipGradByNorm(paddle.nn.ClipGradByNorm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
obj定义了没有用到
…test_clip_grad_by_norm.py, modify test_pixel_shuffle.py
PR types:New features PR types:New features PR types:New features |
PR types:New features PR types:New features PR types:New features PR types:New features |
hi,几个任务尽量别放在一个PR里,一个Task对应一个PR,避免Task间的相互影响合入,以后的PR需要注意下 谢谢~ |
def test_clip_grad_by_global_norm1(): | ||
""" | ||
input shape | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc部分详细描述下该测试用例的验证点,测试输入,预期结果等,通过doc部分可以清除的知晓该用例的测试点。,其他的case也类似
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
您好,已经添加了每个测试案例的相关注释
您好,已经添加了测试部分的注释 |
|
||
# compare grad value computed by numpy and paddle | ||
for res, p_res in zip(np_res, paddle_res): | ||
compare(res[1], p_res[1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要完善测试case
1、尝试不同输入类型下,目前我看都是float32,需要验证其他数据类型下是否符合预期。
2、输入shape为一维、
3、该类的功能中说明(将一个 Tensor列表 t_list 中所有Tensor的L2范数之和,限定在 clip_norm 范围内),case中需要显示验证这点。
up_factor = 3 | ||
data_format = "NCHW" | ||
res = pixel_shuffle_np(x, up_factor, data_format=data_format) | ||
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要补充up_factor的值和输入x中的不能整除时的情况case。当前的case中up_factor(3)都是刚好能正常分解输入x。
tensor_x = paddle.to_tensor(x) | ||
paddle_res = paddle.nn.UpsamplingBilinear2D(size=size, scale_factor=scale_factor, data_format=data_format)(tensor_x) | ||
paddle_res = paddle_res.numpy() | ||
compare(res, paddle_res, delta, rtol) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size (list|tuple|Tensor|None) - 输出Tensor,输入为4D张量,形状为为(out_h, out_w)的2-D Tensor。如果 size 是列表,每一个元素可以是整数或者形状为[1]的变量。如果 size 是变量,则其维度大小为1。默认值为None。
scale_factor (float|Tensor|list|tuple|None)-输入的高度或宽度的乘数因子。 size 和 scale_factor 至少要设置一个。 size 的优先级高于 scale_factor 。默认值为None。如果 scale_factor 是一个list或tuple,它必须与输入的shape匹配。
针对类的不同参数添加更丰富的case,不同的输入类型,shape等、异常情况等
data_format = "NCHW" | ||
with pytest.raises(ValueError): | ||
pixel_shuffle = paddle.nn.PixelShuffle(upscale_factor=up_factor, data_format=data_format) | ||
pixel_shuffle(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以使用apibase里面的异常类。需要判断paddle抛出异常时的报错信息
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
您好,这里需要在实例化类的时候传递参数,然后运行该实例时才会引发异常,请问该使用apibase的哪个类呢。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exception方法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
您好,我已经改好了,请帮忙再review一下。
PR types
New features
PR changes
添加/framwork/api/nn/test_pixel_shuffle.py
Describe
Task: #35904
添加paddle.nn.PixelShuffle单测