diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index e36b9bca93cd0..8320b87b1a81f 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -694,7 +694,8 @@ def parse_index(x, indices): ) estimated_dim = 0 - for dim, slice_item in enumerate(indices): + dim = 0 + for i, slice_item in enumerate(indices): start, end, step = None, None, None if is_integer_or_scalar_tensor(slice_item): if ( @@ -719,6 +720,7 @@ def parse_index(x, indices): start = slice_item step = 1 end = slice_item + 1 if slice_item != -1 else MAX_INTEGER + dim += 1 elif isinstance(slice_item, bool): # single bool is advanced-indexing none_axes.append(dim) @@ -733,7 +735,7 @@ def parse_index(x, indices): end = slice_item.stop step = slice_item.step estimated_dim += 1 - + dim += 1 if start is None and end is None and step is None: continue @@ -761,7 +763,7 @@ def parse_index(x, indices): has_advanced_index = True estimated_dim += 1 - + dim += 1 elif isinstance(slice_item, paddle.base.Variable): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. if ( @@ -781,7 +783,7 @@ def parse_index(x, indices): advanced_index[estimated_dim] = (estimated_dim, slice_item) has_advanced_index = True estimated_dim += 1 - + dim += 1 elif isinstance(slice_item, paddle.pir.OpResult): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. if slice_item.dtype == paddle.pir.core.DataType.BOOL: @@ -798,7 +800,7 @@ def parse_index(x, indices): advanced_index[estimated_dim] = (estimated_dim, slice_item) has_advanced_index = True estimated_dim += 1 - + dim += 1 else: raise IndexError( "Valid index accept int / bool / slice / ellipsis / list / Tuple / Ndarray / Tensor, but received {}.".format( @@ -809,7 +811,7 @@ def parse_index(x, indices): starts.append(start) ends.append(end) steps.append(step) - axes.append(dim) + axes.append(dim - 1) use_strided_slice = ( True if ( diff --git a/test/indexing/test_getitem.py b/test/indexing/test_getitem.py index 19f4ce557950f..d8f95c7458c32 100644 --- a/test/indexing/test_getitem.py +++ b/test/indexing/test_getitem.py @@ -1146,8 +1146,3 @@ def test_bool_shape_error2(self): x = paddle.randn((4, 3, 2)) with self.assertRaises(IndexError): y = _getitem_static(x, (1, paddle.to_tensor([True, False]), [0, 1])) - - def test_bool_error(self): - x = paddle.randn((4, 3, 2)) - with self.assertRaises(ValueError): - y = _getitem_static(x, (True, [0, 1]))