From cf897b03ec59924167994b066226192bd2d1ea50 Mon Sep 17 00:00:00 2001 From: JYChen Date: Mon, 26 Feb 2024 11:23:20 +0800 Subject: [PATCH] fix shape error in combine-getitem (#61922) --- paddle/fluid/pybind/eager_method.cc | 10 ++++---- python/paddle/base/variable_index.py | 4 ++-- test/indexing/test_getitem.py | 34 ++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index d73f56cbdd212..584d1b8b58482 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1420,14 +1420,14 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self, if (pos_of_new_dim != 0) { std::vector perm(out.shape().size(), 0); - int tmp1 = pos_of_new_dim, tmp2 = 0, + int tmp1 = rank_of_new_dim, tmp2 = 0, tmp3 = pos_of_new_dim + rank_of_new_dim; for (int i = 0; i < static_cast(out.shape().size()); ++i) { - if (i < rank_of_new_dim) { + if (i < pos_of_new_dim) { perm[i] = - tmp1++; // range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim) - } else if (i >= rank_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) { - perm[i] = tmp2++; // range(0, pos_of_new_dim) + tmp1++; // range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim) + } else if (i >= pos_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) { + perm[i] = tmp2++; // range(0, rank_of_new_dim) } else { perm[i] = tmp3++; // range(pos_of_new_dim + rank_of_new_dim, out.ndim) } diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index 4851f190b996f..533d0360764a8 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -848,8 +848,8 @@ def _getitem_static(x, indices): if pos_of_new_dim != 0: perm = ( - list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim)) - + list(range(0, pos_of_new_dim)) + list(range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim)) + + list(range(0, rank_of_new_dim)) + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim)) ) out = out.transpose(perm) diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py index f3a2374ecbe1d..3959bde43d152 100644 --- a/test/indexing/test_getitem.py +++ b/test/indexing/test_getitem.py @@ -233,6 +233,26 @@ def test_combined_index_11(self): np.testing.assert_allclose(y.numpy(), np_res) + def test_combined_index_12(self): + np_data = ( + np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) + ) + + if self.dtype == 'bfloat16': + np_data = convert_uint16_to_float(convert_float_to_uint16(np_data)) + if self.dtype == 'complex64' or self.dtype == 'complex128': + np_data = np_data + 1j * np_data + + np_res = np_data[:, :, [2, 4], :] + + x = paddle.to_tensor(np_data, dtype=self.dtype) + y = x[:, :, [2, 4], :] + + if self.dtype == 'bfloat16': + y = paddle.cast(y, dtype='float32') + + np.testing.assert_allclose(y.numpy(), np_res) + def test_index_has_range(self): np_data = ( np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype) @@ -970,6 +990,20 @@ def test_combined_index_11(self): np.testing.assert_allclose(res[0], np_res) + def test_combined_index_12(self): + np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)) + np_res = np_data[:, :, [2, 4], :] + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x = paddle.to_tensor(np_data) + y = _getitem_static( + x, (slice(None), slice(None), [2, 4], slice(None)) + ) + res = self.exe.run(fetch_list=[y]) + + np.testing.assert_allclose(res[0], np_res) + def test_index_has_range(self): # only one bool tensor with all False np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))