Skip to content

Commit

Permalink
【PIR API adaptor No.181】Migrate repeat_interleave op into pir and enable cond test in pir (#59500)
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 authored Dec 7, 2023
1 parent 0835df8 commit ff9a5f6
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 36 deletions.
7 changes: 3 additions & 4 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4958,14 +4958,13 @@ def repeat_interleave(x, repeats, axis=None, name=None):
Tensor(shape=[12], dtype=int64, place=Place(cpu), stop_gradient=True,
[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])
"""

if isinstance(repeats, Variable) and not repeats.shape:
if isinstance(repeats, (Variable, paddle.pir.Value)) and not repeats.shape:
repeats = paddle.reshape(repeats, [1])
if axis is None:
x = paddle.flatten(x)
axis = 0
if in_dynamic_mode():
if isinstance(repeats, Variable):
if in_dynamic_or_pir_mode():
if isinstance(repeats, (Variable, paddle.pir.Value)):
return _C_ops.repeat_interleave_with_tensor_index(x, repeats, axis)
return _C_ops.repeat_interleave(x, repeats, axis)

Expand Down
3 changes: 3 additions & 0 deletions test/legacy_test/test_linalg_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import paddle
from paddle import static
from paddle.pir_utils import test_with_pir_api

p_list_n_n = ("fro", "nuc", 1, -1, np.inf, -np.inf)
p_list_m_n = (None, 2, -2)
Expand Down Expand Up @@ -82,6 +83,7 @@ def gen_empty_input():


class API_TestStaticCond(unittest.TestCase):
@test_with_pir_api
def test_out(self):
paddle.enable_static()
# test calling results of 'cond' in static graph mode
Expand Down Expand Up @@ -115,6 +117,7 @@ def test_dygraph_api_error(self):
x_tensor = paddle.to_tensor(x)
self.assertRaises(ValueError, paddle.linalg.cond, x_tensor, p)

@test_with_pir_api
def test_static_api_error(self):
paddle.enable_static()
# test raising errors when 'cond' is called in static graph mode
Expand Down
84 changes: 52 additions & 32 deletions test/legacy_test/test_repeat_interleave_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import paddle
from paddle import base
from paddle.base import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestRepeatInterleaveOp(OpTest):
Expand Down Expand Up @@ -58,10 +58,10 @@ def init_dtype_type(self):
self.index_size = self.x_shape[self.dim]

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestRepeatInterleaveOp2(OpTest):
Expand Down Expand Up @@ -96,10 +96,10 @@ def init_dtype_type(self):
self.index_size = self.x_shape[self.dim]

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestIndexSelectAPI(unittest.TestCase):
Expand All @@ -115,95 +115,115 @@ def input_data(self):
self.data_zero_dim_index = np.array(2)
self.data_index = np.array([0, 1, 2, 1]).astype('int32')

@test_with_pir_api
def test_repeat_interleave_api(self):
paddle.enable_static()
self.input_data()

# case 1:
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
x.desc.set_need_check_feed(False)
index = paddle.static.data(
name='repeats_',
shape=[4],
dtype='int32',
)
index.desc.set_need_check_feed(False)
if not paddle.framework.in_pir_mode():
x.desc.set_need_check_feed(False)
index.desc.set_need_check_feed(False)
x.stop_gradient = False
index.stop_gradient = False
z = paddle.repeat_interleave(x, index, axis=1)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x, 'repeats_': self.data_index},
fetch_list=[z.name],
fetch_list=[z],
return_numpy=False,
)
expect_out = np.repeat(self.data_x, self.data_index, axis=1)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

# case 2:
repeats = np.array([1, 2, 1]).astype('int32')
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 4], dtype="float32")
x.desc.set_need_check_feed(False)
index = paddle.static.data(
name='repeats_',
shape=[3],
dtype='int32',
)
index.desc.set_need_check_feed(False)
if not paddle.framework.in_pir_mode():
x.desc.set_need_check_feed(False)
index.desc.set_need_check_feed(False)
z = paddle.repeat_interleave(x, index, axis=0)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={
'x': self.data_x,
'repeats_': repeats,
},
fetch_list=[z.name],
fetch_list=[z],
return_numpy=False,
)
expect_out = np.repeat(self.data_x, repeats, axis=0)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

repeats = 2
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
x.desc.set_need_check_feed(False)
z = paddle.repeat_interleave(x, repeats, axis=0)
if not paddle.framework.in_pir_mode():
x.desc.set_need_check_feed(False)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x}, fetch_list=[z.name], return_numpy=False
feed={'x': self.data_x}, fetch_list=[z], return_numpy=False
)
expect_out = np.repeat(self.data_x, repeats, axis=0)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

# case 3 zero_dim:
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[-1], dtype="float32")
x.desc.set_need_check_feed(False)
z = paddle.repeat_interleave(x, repeats)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_zero_dim_x},
fetch_list=[z.name],
return_numpy=False,
)
expect_out = np.repeat(self.data_zero_dim_x, repeats)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
if not paddle.framework.in_pir_mode():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1], dtype="float32")
if not paddle.framework.in_pir_mode():
x.desc.set_need_check_feed(False)
z = paddle.repeat_interleave(x, repeats)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_zero_dim_x},
fetch_list=[z],
return_numpy=False,
)
expect_out = np.repeat(self.data_zero_dim_x, repeats)
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

# case 4 negative axis:
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
x.desc.set_need_check_feed(False)
index = paddle.static.data(
name='repeats_',
shape=[4],
dtype='int32',
)
index.desc.set_need_check_feed(False)

if not paddle.framework.in_pir_mode():
x.desc.set_need_check_feed(False)
index.desc.set_need_check_feed(False)
z = paddle.repeat_interleave(x, index, axis=-1)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x, 'repeats_': self.data_index},
fetch_list=[z.name],
fetch_list=[z],
return_numpy=False,
)
expect_out = np.repeat(self.data_x, self.data_index, axis=-1)
Expand Down

0 comments on commit ff9a5f6

Please sign in to comment.