Skip to content

Commit

Permalink
fix: don't use iter() in _calculate_out_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Aug 28, 2024
1 parent 49bc204 commit 20c65a8
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions ivy/functional/ivy/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,34 @@
from ivy.utils.exceptions import handle_exceptions


def _calculate_out_shape(axis, array_shape):
def _calculate_out_shape(axis: Union[int, Sequence[int]], array_shape: Sequence[int]) -> List[int]:
"""
Calculate the output shape for expanding dimensions of an array.
Parameters
----------
axis : int or sequence of ints
The axis or axes along which to expand the shape.
array_shape : sequence of ints
The shape of the input array.
Returns
-------
list
The calculated output shape.
"""
if type(axis) not in (tuple, list):
axis = (axis,)
out_dims = len(axis) + len(array_shape)
norm_axis = normalize_axis_tuple(axis, out_dims)
shape_iter = iter(array_shape)
out_shape = [
1 if current_ax in norm_axis else next(shape_iter)
for current_ax in range(out_dims)
]
array_shape_index = 0
out_shape = []
for current_ax in range(out_dims):
if current_ax in norm_axis:
out_shape.append(1)
else:
out_shape.append(array_shape[array_shape_index])
array_shape_index += 1
return out_shape


Expand Down

0 comments on commit 20c65a8

Please sign in to comment.