Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PaddlePaddle hackathon] paddle.nn.PixelShuffle单测 #226

Merged
merged 37 commits into from
Oct 25, 2021
Merged
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5d6082b
paddle.nn.PixelShuffle单测提交
justld Oct 10, 2021
2be3336
提交paddle.nn.PixelShuffle单测案例
justld Oct 10, 2021
cd161f4
add test of paddle.nn.ClipGradByGlobalNorm
justld Oct 10, 2021
b1af939
add test paddle.nn.ClipGradByNorm
justld Oct 10, 2021
151bdb8
add test of paddle.nn.PixelShuffle
justld Oct 10, 2021
fa2c79e
add test of paddle.nn.ClipGradByGlobalNorm and paddle.nn.ClipGradByNorm
justld Oct 10, 2021
7736df7
remove useless obj and class in test_clip_grad_by_global_norm.py and …
justld Oct 11, 2021
710ef80
Merge branch 'develop' into develop
justld Oct 11, 2021
d315fde
add test of paddle.nn.UpsampingBinlinear2D
justld Oct 11, 2021
0613680
remove unused code in test_flip_grad_by_global_norm.py
justld Oct 11, 2021
b7ce3bd
remove unused code in test_clip_grad_by_norm.py
justld Oct 11, 2021
078620c
add code annotation in test_clip_grad_by_global_norm.py
justld Oct 12, 2021
eafb154
add code annotation in test_clip_grad_by_norm.py
justld Oct 12, 2021
1978407
add code annotation in test_pixel_shuffle.py
justld Oct 12, 2021
143384b
add annotation in test_upsampling_bilinear2D.py
justld Oct 12, 2021
9d028b3
Merge branch 'develop' into develop
justld Oct 12, 2021
6a61fef
Merge branch 'PaddlePaddle:develop' into develop
justld Oct 13, 2021
3498837
add paddle.ClipGradByGlobalNorm test case
justld Oct 13, 2021
6563226
Merge branch 'develop' of github.com:justld/PaddleTest into develop
justld Oct 13, 2021
dd88b27
add paddle.nn.ClipGradByNorm test case
justld Oct 13, 2021
57b2dab
add paddle.nn.PixelShuffle test case
justld Oct 13, 2021
7667896
add paddle.nn.UpsamplingBilinear2D test case
justld Oct 13, 2021
54749f1
Merge branch 'develop' into develop
justld Oct 13, 2021
4042681
fix bug in test_clip_grad_by_norm.py
justld Oct 13, 2021
946fc10
Merge branch 'develop' of github.com:justld/PaddleTest into develop
justld Oct 13, 2021
fc0b49d
remove 3 test casse
justld Oct 13, 2021
2c12d1e
fix annotation
justld Oct 14, 2021
19edbf7
Merge branch 'develop' into develop
justld Oct 14, 2021
57b4e8e
Merge branch 'develop' into develop
justld Oct 14, 2021
df80e6f
Merge branch 'PaddlePaddle:develop' into develop
justld Oct 15, 2021
98f5d30
refine exception raise code
justld Oct 15, 2021
ddfebee
Merge branch 'develop' into develop
justld Oct 16, 2021
3a2f853
Merge branch 'develop' into develop
justld Oct 18, 2021
e25f69c
Merge branch 'develop' into develop
justld Oct 19, 2021
655823b
Merge branch 'develop' into develop
DDDivano Oct 22, 2021
e9ee331
Merge branch 'develop' into develop
DDDivano Oct 22, 2021
1974736
Merge branch 'develop' into develop
DDDivano Oct 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions framework/api/nn/test_pixel_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/bin/env python
# -*- coding: utf-8 -*-
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
"""
test_pixel_shuffle
"""

from apibase import APIBase
from apibase import randtool, compare
import paddle
import pytest
import numpy as np


def pixel_shuffle_np(x, up_factor, data_format="NCHW"):
"""
pixel shuffle implemented by numpy.
"""
if data_format == "NCHW":
n, c, h, w = x.shape
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h, w)
npresult = np.reshape(x, new_shape)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
npreslut = np.reshape(npresult, oshape)
return npreslut
else:
n, h, w, c = x.shape
new_shape = (n, h, w, c // (up_factor * up_factor), up_factor, up_factor)
npresult = np.reshape(x, new_shape)
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
oshape = [n, h * up_factor, w * up_factor, c // (up_factor * up_factor)]
npresult = np.reshape(npresult, oshape)
return npresult


class TestPixelShuffle(APIBase):
"""
test
"""

def hook(self):
"""
implement
"""
self.types = [np.float32, np.float64]


obj = TestPixelShuffle(paddle.nn.PixelShuffle)


@pytest.mark.api_nn_PixelShuffle_vartype
def test_pixel_shuffle_base():
"""
Test base.

Test base config:
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"

Expected Results:
The output of pixel shuffle implemented by numpy and paddle should be equal.
"""
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"
res = pixel_shuffle_np(x, up_factor, data_format=data_format)
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format)


@pytest.mark.api_nn_PixelShuffle_parameters
def test_pixel_shuffle_norm1():
"""
Test pixel shuffle when input shape changes.

Test Base config:
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"

Changes:
input shape: [2, 9, 4, 4] -> [4, 81, 4, 4]

Expected Results:
The output of pixel shuffle implemented by numpy and paddle should be equal.
"""
x = randtool("float", -10, 10, [4, 81, 4, 4])
up_factor = 3
data_format = "NCHW"
res = pixel_shuffle_np(x, up_factor, data_format=data_format)
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format)


@pytest.mark.api_nn_PixelShuffle_parameters
def test_pixel_shuffle_norm2():
"""
Test pixel shuffle when data_format changes.

Test Base config:
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"

Changes:
input shape: [2, 9, 4, 4] -> [2, 4, 4, 9]
data_format: 'NCHW' -> 'NHWC'

Expected Results:
The output of pixel shuffle implemented by numpy and paddle should be equal.
"""
x = randtool("float", -10, 10, [2, 4, 4, 9])
up_factor = 3
data_format = "NHWC"
res = pixel_shuffle_np(x, up_factor, data_format=data_format)
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format)


@pytest.mark.api_nn_PixelShuffle_parameters
def test_pixel_shuffle_norm3():
"""
Test pixel shuffle when input data channels cann't be factorized by upscale_factor.

Test Base config:
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"

Changes:
up_factor: 3 -> 4

Expected Results:
when input data channels cann't be factorized by upscale_factor, raise ValueError.
"""
x = paddle.rand(shape=[2, 9, 4, 4])
up_factor = 4
data_format = "NCHW"
obj.exception(ValueError, mode="python", data=x, upscale_factor=up_factor, data_format=data_format)


@pytest.mark.api_nn_PixelShuffle_parameters
def test_pixel_shuffle_norm4():
"""
Test pixel shuffle when input data dtype changes.

Test Base config:
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"

Changes:
input data dtype: float -> int

Expected Results:
The output of pixel shuffle implemented by numpy and paddle should be equal.
"""
x = randtool("int", -10, 10, [4, 9, 4, 4])
up_factor = 3
data_format = "NCHW"
res = pixel_shuffle_np(x, up_factor, data_format=data_format)
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format)


@pytest.mark.api_nn_PixelShuffle_parameters
def test_pixel_shuffle_norm5():
"""
Test pixel shuffle when input value range changes.

Test Base config:
x = randtool("float", -10, 10, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"

Changes:
input value: range(-10, 10) -> range(-2555, 2555)

Expected Results:
The output of pixel shuffle implemented by numpy and paddle should be equal.
"""
x = randtool("float", -2555, 2555, [2, 9, 4, 4])
up_factor = 3
data_format = "NCHW"
res = pixel_shuffle_np(x, up_factor, data_format=data_format)
obj.run(res=res, data=x, upscale_factor=up_factor, data_format=data_format)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要补充up_factor的值和输入x中的不能整除时的情况case。当前的case中up_factor(3)都是刚好能正常分解输入x。