fix shape error in combine-getitem (PaddlePaddle#61922)
zoooo0820 committed Feb 27, 2024
1 parent e94fc7a commit cf897b0
Showing 3 changed files with 41 additions and 7 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/pybind/eager_method.cc
@@ -1420,14 +1420,14 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,

     if (pos_of_new_dim != 0) {
       std::vector<int> perm(out.shape().size(), 0);
-      int tmp1 = pos_of_new_dim, tmp2 = 0,
+      int tmp1 = rank_of_new_dim, tmp2 = 0,
           tmp3 = pos_of_new_dim + rank_of_new_dim;
       for (int i = 0; i < static_cast<int>(out.shape().size()); ++i) {
-        if (i < rank_of_new_dim) {
+        if (i < pos_of_new_dim) {
           perm[i] =
-              tmp1++;  // range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim)
-        } else if (i >= rank_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
-          perm[i] = tmp2++;  // range(0, pos_of_new_dim)
+              tmp1++;  // range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim)
+        } else if (i >= pos_of_new_dim && i < pos_of_new_dim + rank_of_new_dim) {
+          perm[i] = tmp2++;  // range(0, rank_of_new_dim)
         } else {
           perm[i] = tmp3++;  // range(pos_of_new_dim + rank_of_new_dim, out.ndim)
         }
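The fix swaps the roles of `pos_of_new_dim` and `rank_of_new_dim` when building the transpose permutation that moves the advanced-index dimensions from the front of the gathered result back to `pos_of_new_dim`. A minimal Python sketch of the corrected loop, assuming `pos_of_new_dim`, `rank_of_new_dim`, and `ndim` are given (for illustration only, not Paddle source):

def build_perm(pos_of_new_dim, rank_of_new_dim, ndim):
    # Mirrors the corrected C++ loop above.
    perm = [0] * ndim
    tmp1 = rank_of_new_dim                    # walks range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim)
    tmp2 = 0                                  # walks range(0, rank_of_new_dim)
    tmp3 = pos_of_new_dim + rank_of_new_dim   # walks range(pos_of_new_dim + rank_of_new_dim, ndim)
    for i in range(ndim):
        if i < pos_of_new_dim:
            perm[i] = tmp1
            tmp1 += 1
        elif i < pos_of_new_dim + rank_of_new_dim:
            perm[i] = tmp2
            tmp2 += 1
        else:
            perm[i] = tmp3
            tmp3 += 1
    return perm

print(build_perm(2, 1, 4))  # [1, 2, 0, 3], e.g. the x[:, :, [2, 4], :] case tested below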
4 changes: 2 additions & 2 deletions python/paddle/base/variable_index.py
@@ -848,8 +848,8 @@ def _getitem_static(x, indices):

     if pos_of_new_dim != 0:
         perm = (
-            list(range(pos_of_new_dim, pos_of_new_dim + rank_of_new_dim))
-            + list(range(0, pos_of_new_dim))
+            list(range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim))
+            + list(range(0, rank_of_new_dim))
             + list(range(pos_of_new_dim + rank_of_new_dim, out.ndim))
         )
         out = out.transpose(perm)
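For the new test case below (`x[:, :, [2, 4], :]` on a tensor of shape (3, 4, 5, 6)), `pos_of_new_dim` is 2 and `rank_of_new_dim` is 1, so the intermediate with the gathered dimension at the front has shape (2, 3, 4, 6). A quick NumPy check of the corrected permutation, with `np.moveaxis` standing in for Paddle's intermediate (an assumption made only for this sketch):

import numpy as np

pos_of_new_dim, rank_of_new_dim = 2, 1  # assumed values for x[:, :, [2, 4], :]
expected = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)[:, :, [2, 4], :]  # (3, 4, 2, 6)
gathered = np.moveaxis(expected, 2, 0)  # stand-in intermediate, shape (2, 3, 4, 6)

perm = (
    list(range(rank_of_new_dim, pos_of_new_dim + rank_of_new_dim))   # [1, 2]
    + list(range(0, rank_of_new_dim))                                # [0]
    + list(range(pos_of_new_dim + rank_of_new_dim, gathered.ndim))   # [3]
)
assert perm == [1, 2, 0, 3]
assert np.array_equal(gathered.transpose(perm), expected)  # shape (3, 4, 2, 6)
# The pre-fix permutation [2, 0, 1, 3] would instead yield shape (4, 2, 3, 6).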
34 changes: 34 additions & 0 deletions test/indexing/test_getitem.py
@@ -233,6 +233,26 @@ def test_combined_index_11(self):

         np.testing.assert_allclose(y.numpy(), np_res)

+    def test_combined_index_12(self):
+        np_data = (
+            np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype)
+        )
+
+        if self.dtype == 'bfloat16':
+            np_data = convert_uint16_to_float(convert_float_to_uint16(np_data))
+        if self.dtype == 'complex64' or self.dtype == 'complex128':
+            np_data = np_data + 1j * np_data
+
+        np_res = np_data[:, :, [2, 4], :]
+
+        x = paddle.to_tensor(np_data, dtype=self.dtype)
+        y = x[:, :, [2, 4], :]
+
+        if self.dtype == 'bfloat16':
+            y = paddle.cast(y, dtype='float32')
+
+        np.testing.assert_allclose(y.numpy(), np_res)
+
     def test_index_has_range(self):
         np_data = (
             np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6)).astype(self.ndtype)
@@ -970,6 +990,20 @@ def test_combined_index_11(self):

         np.testing.assert_allclose(res[0], np_res)

+    def test_combined_index_12(self):
+        np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
+        np_res = np_data[:, :, [2, 4], :]
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
+            x = paddle.to_tensor(np_data)
+            y = _getitem_static(
+                x, (slice(None), slice(None), [2, 4], slice(None))
+            )
+            res = self.exe.run(fetch_list=[y])
+
+        np.testing.assert_allclose(res[0], np_res)
+
     def test_index_has_range(self):
         # only one bool tensor with all False
         np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))