-
Notifications
You must be signed in to change notification settings - Fork 115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PaddlePaddle hackathon] paddle.nn.PixelShuffle单测 #226
Merged
Merged
Changes from 29 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
5d6082b
paddle.nn.PixelShuffle单测提交
justld 2be3336
提交paddle.nn.PixelShuffle单测案例
justld cd161f4
add test of paddle.nn.ClipGradByGlobalNorm
justld b1af939
add test paddle.nn.ClipGradByNorm
justld 151bdb8
add test of paddle.nn.PixelShuffle
justld fa2c79e
add test of paddle.nn.ClipGradByGlobalNorm and paddle.nn.ClipGradByNorm
justld 7736df7
remove useless obj and class in test_clip_grad_by_global_norm.py and …
justld 710ef80
Merge branch 'develop' into develop
justld d315fde
add test of paddle.nn.UpsampingBinlinear2D
justld 0613680
remove unused code in test_flip_grad_by_global_norm.py
justld b7ce3bd
remove unused code in test_clip_grad_by_norm.py
justld 078620c
add code annotation in test_clip_grad_by_global_norm.py
justld eafb154
add code annotation in test_clip_grad_by_norm.py
justld 1978407
add code annotation in test_pixel_shuffle.py
justld 143384b
add annotation in test_upsampling_bilinear2D.py
justld 9d028b3
Merge branch 'develop' into develop
justld 6a61fef
Merge branch 'PaddlePaddle:develop' into develop
justld 3498837
add paddle.ClipGradByGlobalNorm test case
justld 6563226
Merge branch 'develop' of github.com:justld/PaddleTest into develop
justld dd88b27
add paddle.nn.ClipGradByNorm test case
justld 57b2dab
add paddle.nn.PixelShuffle test case
justld 7667896
add paddle.nn.UpsamplingBilinear2D test case
justld 54749f1
Merge branch 'develop' into develop
justld 4042681
fix bug in test_clip_grad_by_norm.py
justld 946fc10
Merge branch 'develop' of github.com:justld/PaddleTest into develop
justld fc0b49d
remove 3 test casse
justld 2c12d1e
fix annotation
justld 19edbf7
Merge branch 'develop' into develop
justld 57b4e8e
Merge branch 'develop' into develop
justld df80e6f
Merge branch 'PaddlePaddle:develop' into develop
justld 98f5d30
refine exception raise code
justld ddfebee
Merge branch 'develop' into develop
justld 3a2f853
Merge branch 'develop' into develop
justld e25f69c
Merge branch 'develop' into develop
justld 655823b
Merge branch 'develop' into develop
DDDivano e9ee331
Merge branch 'develop' into develop
DDDivano 1974736
Merge branch 'develop' into develop
DDDivano File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
#!/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python | ||
""" | ||
test_pixel_shuffle | ||
""" | ||
|
||
from apibase import APIBase | ||
from apibase import randtool, compare | ||
import paddle | ||
import pytest | ||
import numpy as np | ||
|
||
|
||
def pixel_shuffle_np(x, up_factor, data_format="NCHW"): | ||
""" | ||
pixel shuffle implemented by numpy. | ||
""" | ||
if data_format == "NCHW": | ||
n, c, h, w = x.shape | ||
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h, w) | ||
npresult = np.reshape(x, new_shape) | ||
npresult = npresult.transpose(0, 1, 4, 2, 5, 3) | ||
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor] | ||
npreslut = np.reshape(npresult, oshape) | ||
return npreslut | ||
else: | ||
n, h, w, c = x.shape | ||
new_shape = (n, h, w, c // (up_factor * up_factor), up_factor, up_factor) | ||
npresult = np.reshape(x, new_shape) | ||
npresult = npresult.transpose(0, 1, 4, 2, 5, 3) | ||
oshape = [n, h * up_factor, w * up_factor, c // (up_factor * up_factor)] | ||
npresult = np.reshape(npresult, oshape) | ||
return npresult | ||
|
||
|
||
class TestPixelShuffle(APIBase): | ||
""" | ||
test | ||
""" | ||
|
||
def hook(self): | ||
""" | ||
implement | ||
""" | ||
self.types = [np.float32, np.float64] | ||
|
||
|
||
obj = TestPixelShuffle(paddle.nn.PixelShuffle) | ||
|
||
|
||
@pytest.mark.api_nn_PixelShuffle_vartype | ||
def test_pixel_shuffle_base(): | ||
""" | ||
Test base. | ||
|
||
Test base config: | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
|
||
Expected Results: | ||
The output of pixel shuffle implemented by numpy and paddle should be equal. | ||
""" | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
res = pixel_shuffle_np(x, up_factor, data_format=data_format) | ||
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format) | ||
|
||
|
||
@pytest.mark.api_nn_PixelShuffle_parameters | ||
def test_pixel_shuffle_norm1(): | ||
""" | ||
Test pixel shuffle when input shape changes. | ||
|
||
Test Base config: | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
|
||
Changes: | ||
input shape: [2, 9, 4, 4] -> [4, 81, 4, 4] | ||
|
||
Expected Results: | ||
The output of pixel shuffle implemented by numpy and paddle should be equal. | ||
""" | ||
x = randtool("float", -10, 10, [4, 81, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
res = pixel_shuffle_np(x, up_factor, data_format=data_format) | ||
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format) | ||
|
||
|
||
@pytest.mark.api_nn_PixelShuffle_parameters | ||
def test_pixel_shuffle_norm2(): | ||
""" | ||
Test pixel shuffle when data_format changes. | ||
|
||
Test Base config: | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
|
||
Changes: | ||
input shape: [2, 9, 4, 4] -> [2, 4, 4, 9] | ||
data_format: 'NCHW' -> 'NHWC' | ||
|
||
Expected Results: | ||
The output of pixel shuffle implemented by numpy and paddle should be equal. | ||
""" | ||
x = randtool("float", -10, 10, [2, 4, 4, 9]) | ||
up_factor = 3 | ||
data_format = "NHWC" | ||
res = pixel_shuffle_np(x, up_factor, data_format=data_format) | ||
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format) | ||
|
||
|
||
@pytest.mark.api_nn_PixelShuffle_parameters | ||
def test_pixel_shuffle_norm3(): | ||
""" | ||
Test pixel shuffle when input data channels cann't be factorized by upscale_factor. | ||
|
||
Test Base config: | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
|
||
Changes: | ||
up_factor: 3 -> 4 | ||
|
||
Expected Results: | ||
when input data channels cann't be factorized by upscale_factor, raise ValueError. | ||
""" | ||
x = paddle.rand(shape=[2, 9, 4, 4]) | ||
up_factor = 4 | ||
data_format = "NCHW" | ||
with pytest.raises(ValueError): | ||
pixel_shuffle = paddle.nn.PixelShuffle(upscale_factor=up_factor, data_format=data_format) | ||
pixel_shuffle(x) | ||
|
||
|
||
@pytest.mark.api_nn_PixelShuffle_parameters | ||
def test_pixel_shuffle_norm4(): | ||
""" | ||
Test pixel shuffle when input data dtype changes. | ||
|
||
Test Base config: | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
|
||
Changes: | ||
input data dtype: float -> int | ||
|
||
Expected Results: | ||
The output of pixel shuffle implemented by numpy and paddle should be equal. | ||
""" | ||
x = randtool("int", -10, 10, [4, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
res = pixel_shuffle_np(x, up_factor, data_format=data_format) | ||
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format) | ||
|
||
|
||
@pytest.mark.api_nn_PixelShuffle_parameters | ||
def test_pixel_shuffle_norm5(): | ||
""" | ||
Test pixel shuffle when input value range changes. | ||
|
||
Test Base config: | ||
x = randtool("float", -10, 10, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
|
||
Changes: | ||
input value: range(-10, 10) -> range(-2555, 2555) | ||
|
||
Expected Results: | ||
The output of pixel shuffle implemented by numpy and paddle should be equal. | ||
""" | ||
x = randtool("float", -2555, 2555, [2, 9, 4, 4]) | ||
up_factor = 3 | ||
data_format = "NCHW" | ||
res = pixel_shuffle_np(x, up_factor, data_format=data_format) | ||
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 需要补充up_factor的值和输入x中的不能整除时的情况case。当前的case中up_factor(3)都是刚好能正常分解输入x。 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以使用apibase里面的异常类。需要判断paddle抛出异常时的报错信息
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
您好,这里需要在实例化类的时候传递参数,然后运行该实例时才会引发异常,请问该使用apibase的哪个类呢。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exception方法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
您好,我已经改好了,请帮忙再review一下。