Skip to content

Commit

Permalink
[PIR]Migrate AdaptiveAvgPool2D into pir (PaddlePaddle#58138)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored and jiahy0825 committed Oct 26, 2023
1 parent eea81c0 commit 37c061c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 39 deletions.
6 changes: 3 additions & 3 deletions python/paddle/nn/functional/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,8 +1651,9 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
elif _contain_var(output_size):
output_size = _convert_to_tensor_list(output_size)

if in_dygraph_mode():
x = x._use_gpudnn(False)
if in_dynamic_or_pir_mode():
if in_dygraph_mode():
x = x._use_gpudnn(False)
return _C_ops.pool2d(
x,
output_size,
Expand All @@ -1666,7 +1667,6 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
True,
"EXPLICIT",
)

else:
l_type = 'pool2d'
check_variable_and_dtype(
Expand Down
89 changes: 53 additions & 36 deletions test/legacy_test/test_adaptive_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from test_attribute_var import UnittestBase

import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.pir_utils import test_with_pir_api


def adaptive_start_index(index, input_size, output_size):
Expand Down Expand Up @@ -113,37 +113,45 @@ def setUp(self):
x=self.x_np, output_size=[None, 3], pool_type="avg"
)

@test_with_pir_api
def test_static_graph(self):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.static.data(
name="x", shape=[2, 3, 7, 7], dtype="float32"
)

out_1 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3]
)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()

out_2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=5)
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
name="x", shape=[2, 3, 7, 7], dtype="float32"
)

out_3 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[2, 5]
)
out_1 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3]
)

out_4 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3], data_format="NHWC"
)
out_2 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=5
)

out_5 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[None, 3]
)
out_3 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[2, 5]
)

out_4 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3], data_format="NHWC"
)

out_5 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[None, 3]
)

exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_4, res_5] = exe.run(
base.default_main_program(),
main_program,
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_4, out_5],
)
Expand Down Expand Up @@ -232,38 +240,47 @@ def setUp(self):
x=self.x_np, output_size=[None, 3], pool_type="avg"
)

@test_with_pir_api
def test_static_graph(self):
for use_cuda in (
[False, True] if core.is_compiled_with_cuda() else [False]
):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.static.data(
name="x", shape=[2, 3, 7, 7], dtype="float32"
)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()

adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=[3, 3])
out_1 = adaptive_avg_pool(x=x)
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
name="x", shape=[2, 3, 7, 7], dtype="float32"
)

adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=5)
out_2 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(
output_size=[3, 3]
)
out_1 = adaptive_avg_pool(x=x)

adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=[2, 5])
out_3 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(output_size=5)
out_2 = adaptive_avg_pool(x=x)

adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(
output_size=[3, 3], data_format="NHWC"
)
out_4 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(
output_size=[2, 5]
)
out_3 = adaptive_avg_pool(x=x)

adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(
output_size=[None, 3]
)
out_5 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(
output_size=[3, 3], data_format="NHWC"
)
out_4 = adaptive_avg_pool(x=x)

adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2D(
output_size=[None, 3]
)
out_5 = adaptive_avg_pool(x=x)

exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_4, res_5] = exe.run(
base.default_main_program(),
main_program,
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_4, out_5],
)
Expand Down

0 comments on commit 37c061c

Please sign in to comment.