Skip to content

Commit

Permalink
fix ci error
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Nov 9, 2023
1 parent 6d170bb commit 5c1e797
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
14 changes: 8 additions & 6 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,8 @@ def parse_index(x, indices):
)

estimated_dim = 0
for dim, slice_item in enumerate(indices):
dim = 0
for i, slice_item in enumerate(indices):
start, end, step = None, None, None
if is_integer_or_scalar_tensor(slice_item):
if (
Expand All @@ -719,6 +720,7 @@ def parse_index(x, indices):
start = slice_item
step = 1
end = slice_item + 1 if slice_item != -1 else MAX_INTEGER
dim += 1
elif isinstance(slice_item, bool):
# single bool is advanced-indexing
none_axes.append(dim)
Expand All @@ -733,7 +735,7 @@ def parse_index(x, indices):
end = slice_item.stop
step = slice_item.step
estimated_dim += 1

dim += 1
if start is None and end is None and step is None:
continue

Expand Down Expand Up @@ -761,7 +763,7 @@ def parse_index(x, indices):

has_advanced_index = True
estimated_dim += 1

dim += 1
elif isinstance(slice_item, paddle.base.Variable):
# In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
if (
Expand All @@ -781,7 +783,7 @@ def parse_index(x, indices):
advanced_index[estimated_dim] = (estimated_dim, slice_item)
has_advanced_index = True
estimated_dim += 1

dim += 1
elif isinstance(slice_item, paddle.pir.OpResult):
# In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing.
if slice_item.dtype == paddle.pir.core.DataType.BOOL:
Expand All @@ -798,7 +800,7 @@ def parse_index(x, indices):
advanced_index[estimated_dim] = (estimated_dim, slice_item)
has_advanced_index = True
estimated_dim += 1

dim += 1
else:
raise IndexError(
"Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format(
Expand All @@ -809,7 +811,7 @@ def parse_index(x, indices):
starts.append(start)
ends.append(end)
steps.append(step)
axes.append(dim)
axes.append(dim - 1)
use_strided_slice = (
True
if (
Expand Down
5 changes: 0 additions & 5 deletions test/indexing/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,8 +1146,3 @@ def test_bool_shape_error2(self):
x = paddle.randn((4, 3, 2))
with self.assertRaises(IndexError):
y = _getitem_static(x, (1, paddle.to_tensor([True, False]), [0, 1]))

def test_bool_error(self):
x = paddle.randn((4, 3, 2))
with self.assertRaises(ValueError):
y = _getitem_static(x, (True, [0, 1]))

0 comments on commit 5c1e797

Please sign in to comment.