From 854f6be6c70f2d092681edfdb26a83a1097d2c9a Mon Sep 17 00:00:00 2001 From: Kacper-W-Kozdon Date: Sun, 28 Jan 2024 19:45:32 +0100 Subject: [PATCH] ivy.unflatten --- .../array/experimental/manipulation.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/ivy/data_classes/array/experimental/manipulation.py b/ivy/data_classes/array/experimental/manipulation.py index 013af82658b0a..74dcc59192d7d 100644 --- a/ivy/data_classes/array/experimental/manipulation.py +++ b/ivy/data_classes/array/experimental/manipulation.py @@ -1134,6 +1134,64 @@ def take( self, indices, axis=axis, mode=mode, fill_value=fill_value, out=out ) + def unflatten( + x: Union[int, ivy.Array, ivy.NativeArray], + /, + *, + dim: Optional[int] = 0, + shape: Union[Tuple[int], ivy.Array, ivy.NativeArray], + out: Optional[ivy.Array] = None, + mode: str = "fill", + fill_value: Optional[Number] = None, + ) -> ivy.Array: + """ivy.Array instance method variant of ivy.unflatten. This method + simply wraps the function, and so the docstring for ivy.unflatten also + applies to this method with minimal changes. + + Parameters + ---------- + x + input array + shape + array indices. Must have an integer data type. + dim + axis over which to unflatten. If `axis` is negative, + the function must determine the axis along which to select values + by counting from the last dimension. + By default, the flattened input array is used. + out + optional output array, for writing the result to. It must + have a shape that the inputs broadcast to. + + Returns + ------- + ret + an array having the same data type as `x`. + The output array must have the same rank + (i.e., number of dimensions) as `x` and + must have the same shape as `x`, + except for the axis specified by `dim` + which is replaced with a tuple specified in `shape`. + + + Examples + -------- + With 'ivy.Array' input: + + >>> x = ivy.array([[1.2, 2.3, 3.4, 4.5], + [5.6, 6.7, 7.8, 8.9]]) + >>> shape = (2, 2) + >>> y = ivy.unflatten(x, shape=shape, dim=dim, out=y) + >>> print(y) + ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]]) + """ + return ivy.unflatten( + x, + dim=dim, + shape=shape, + out=out, + ) + def trim_zeros( self: ivy.Array, /,