diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 91a3887d66387..87cb258952f9e 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -4958,14 +4958,13 @@ def repeat_interleave(x, repeats, axis=None, name=None): Tensor(shape=[12], dtype=int64, place=Place(cpu), stop_gradient=True, [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]) """ - - if isinstance(repeats, Variable) and not repeats.shape: + if isinstance(repeats, (Variable, paddle.pir.Value)) and not repeats.shape: repeats = paddle.reshape(repeats, [1]) if axis is None: x = paddle.flatten(x) axis = 0 - if in_dynamic_mode(): - if isinstance(repeats, Variable): + if in_dynamic_or_pir_mode(): + if isinstance(repeats, (Variable, paddle.pir.Value)): return _C_ops.repeat_interleave_with_tensor_index(x, repeats, axis) return _C_ops.repeat_interleave(x, repeats, axis) diff --git a/test/legacy_test/test_linalg_cond.py b/test/legacy_test/test_linalg_cond.py index ffa8ef5777e93..a3b60c0ef6ded 100644 --- a/test/legacy_test/test_linalg_cond.py +++ b/test/legacy_test/test_linalg_cond.py @@ -18,6 +18,7 @@ import paddle from paddle import static +from paddle.pir_utils import test_with_pir_api p_list_n_n = ("fro", "nuc", 1, -1, np.inf, -np.inf) p_list_m_n = (None, 2, -2) @@ -82,6 +83,7 @@ def gen_empty_input(): class API_TestStaticCond(unittest.TestCase): + @test_with_pir_api def test_out(self): paddle.enable_static() # test calling results of 'cond' in static graph mode @@ -115,6 +117,7 @@ def test_dygraph_api_error(self): x_tensor = paddle.to_tensor(x) self.assertRaises(ValueError, paddle.linalg.cond, x_tensor, p) + @test_with_pir_api def test_static_api_error(self): paddle.enable_static() # test raising errors when 'cond' is called in static graph mode diff --git a/test/legacy_test/test_repeat_interleave_op.py b/test/legacy_test/test_repeat_interleave_op.py index 8b2e6e20333f2..b2d0a12c6e260 100644 --- a/test/legacy_test/test_repeat_interleave_op.py +++ b/test/legacy_test/test_repeat_interleave_op.py @@ -19,7 +19,7 @@ import paddle from paddle import base -from paddle.base import Program, program_guard +from paddle.pir_utils import test_with_pir_api class TestRepeatInterleaveOp(OpTest): @@ -58,10 +58,10 @@ def init_dtype_type(self): self.index_size = self.x_shape[self.dim] def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestRepeatInterleaveOp2(OpTest): @@ -96,10 +96,10 @@ def init_dtype_type(self): self.index_size = self.x_shape[self.dim] def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestIndexSelectAPI(unittest.TestCase): @@ -115,25 +115,31 @@ def input_data(self): self.data_zero_dim_index = np.array(2) self.data_index = np.array([0, 1, 2, 1]).astype('int32') + @test_with_pir_api def test_repeat_interleave_api(self): paddle.enable_static() self.input_data() # case 1: - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32') - x.desc.set_need_check_feed(False) index = paddle.static.data( name='repeats_', shape=[4], dtype='int32', ) - index.desc.set_need_check_feed(False) + if not paddle.framework.in_pir_mode(): + x.desc.set_need_check_feed(False) + index.desc.set_need_check_feed(False) + x.stop_gradient = False + index.stop_gradient = False z = paddle.repeat_interleave(x, index, axis=1) exe = base.Executor(base.CPUPlace()) (res,) = exe.run( feed={'x': self.data_x, 'repeats_': self.data_index}, - fetch_list=[z.name], + fetch_list=[z], return_numpy=False, ) expect_out = np.repeat(self.data_x, self.data_index, axis=1) @@ -141,15 +147,18 @@ def test_repeat_interleave_api(self): # case 2: repeats = np.array([1, 2, 1]).astype('int32') - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 4], dtype="float32") - x.desc.set_need_check_feed(False) index = paddle.static.data( name='repeats_', shape=[3], dtype='int32', ) - index.desc.set_need_check_feed(False) + if not paddle.framework.in_pir_mode(): + x.desc.set_need_check_feed(False) + index.desc.set_need_check_feed(False) z = paddle.repeat_interleave(x, index, axis=0) exe = base.Executor(base.CPUPlace()) (res,) = exe.run( @@ -157,53 +166,64 @@ def test_repeat_interleave_api(self): 'x': self.data_x, 'repeats_': repeats, }, - fetch_list=[z.name], + fetch_list=[z], return_numpy=False, ) expect_out = np.repeat(self.data_x, repeats, axis=0) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) repeats = 2 - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32') - x.desc.set_need_check_feed(False) z = paddle.repeat_interleave(x, repeats, axis=0) + if not paddle.framework.in_pir_mode(): + x.desc.set_need_check_feed(False) exe = base.Executor(base.CPUPlace()) (res,) = exe.run( - feed={'x': self.data_x}, fetch_list=[z.name], return_numpy=False + feed={'x': self.data_x}, fetch_list=[z], return_numpy=False ) expect_out = np.repeat(self.data_x, repeats, axis=0) np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) # case 3 zero_dim: - with program_guard(Program(), Program()): - x = paddle.static.data(name='x', shape=[-1], dtype="float32") - x.desc.set_need_check_feed(False) - z = paddle.repeat_interleave(x, repeats) - exe = base.Executor(base.CPUPlace()) - (res,) = exe.run( - feed={'x': self.data_zero_dim_x}, - fetch_list=[z.name], - return_numpy=False, - ) - expect_out = np.repeat(self.data_zero_dim_x, repeats) - np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) + if not paddle.framework.in_pir_mode(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.static.data(name='x', shape=[-1], dtype="float32") + if not paddle.framework.in_pir_mode(): + x.desc.set_need_check_feed(False) + z = paddle.repeat_interleave(x, repeats) + exe = base.Executor(base.CPUPlace()) + (res,) = exe.run( + feed={'x': self.data_zero_dim_x}, + fetch_list=[z], + return_numpy=False, + ) + expect_out = np.repeat(self.data_zero_dim_x, repeats) + np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05) # case 4 negative axis: - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32') - x.desc.set_need_check_feed(False) index = paddle.static.data( name='repeats_', shape=[4], dtype='int32', ) - index.desc.set_need_check_feed(False) + + if not paddle.framework.in_pir_mode(): + x.desc.set_need_check_feed(False) + index.desc.set_need_check_feed(False) z = paddle.repeat_interleave(x, index, axis=-1) exe = base.Executor(base.CPUPlace()) (res,) = exe.run( feed={'x': self.data_x, 'repeats_': self.data_index}, - fetch_list=[z.name], + fetch_list=[z], return_numpy=False, ) expect_out = np.repeat(self.data_x, self.data_index, axis=-1)