Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added Ivy.unflatten #28079

Merged
merged 50 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
53b5db7
ivy.unflatten
Kacper-W-Kozdon Jan 28, 2024
854f6be
ivy.unflatten
Kacper-W-Kozdon Jan 28, 2024
8272232
ivy.unflatten
Kacper-W-Kozdon Jan 28, 2024
0b0115b
Merge branch 'unifyai:main' into ivy_unflatten
Kacper-W-Kozdon Jan 28, 2024
33b07f9
Added test for ivy.unflatten() functional->experimental->core->manipu…
Kacper-W-Kozdon Feb 1, 2024
640dfe9
ivy.unflatten() test
Kacper-W-Kozdon Feb 1, 2024
e50faee
ivy.unflatten() test
Kacper-W-Kozdon Feb 1, 2024
7ebb1f6
ivy.unflatten() test
Kacper-W-Kozdon Feb 1, 2024
48e965c
ivy.unflatten() test
Kacper-W-Kozdon Feb 1, 2024
156024b
ivy.unflatten() test
Kacper-W-Kozdon Feb 1, 2024
581d7f5
ivy.unflatten() test
Kacper-W-Kozdon Feb 1, 2024
015d3d1
ivy.unflatten() test
Kacper-W-Kozdon Feb 2, 2024
beea8ff
ivy.unflatten() test + fixed Container method
Kacper-W-Kozdon Feb 2, 2024
e61aac9
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
7e02acf
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
260b709
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
eff6fb6
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
c1a97ea
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
1811a38
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
e8b7c6c
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
69bd268
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
ae40a14
ivy.unflatten() _static_unflatten() fixing
Kacper-W-Kozdon Feb 2, 2024
923ed52
main pull
Kacper-W-Kozdon Feb 2, 2024
5c8de20
Merge branch 'main' into ivy_unflatten
Kacper-W-Kozdon Feb 2, 2024
7a258cc
Add files via upload
Kacper-W-Kozdon Feb 2, 2024
740622f
Add files via upload
Kacper-W-Kozdon Feb 2, 2024
df73e40
Delete demos directory
Kacper-W-Kozdon Feb 2, 2024
334efbe
remove demos from the commit
Kacper-W-Kozdon Feb 2, 2024
6c92be1
remove demos from the staged commits
Kacper-W-Kozdon Feb 2, 2024
c2f2441
ivy.unflatten test, container_flags = False
Kacper-W-Kozdon Feb 3, 2024
69f44d3
fixing tests for unflatten
Kacper-W-Kozdon Feb 3, 2024
cab1409
fixing tests for unflatten
Kacper-W-Kozdon Feb 3, 2024
425a313
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
01c9eb7
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
41d09c2
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
f8e2221
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
0bc2766
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
eb3a104
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
1ab5152
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
1137015
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
82cac32
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
be30d6b
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
cb87129
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
1ab2dfa
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
7e96bbb
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
8e21032
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
9646b50
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
b22c235
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
edbe4fe
ivy_unflatten fixing errors with shape and dim for array
Kacper-W-Kozdon Feb 3, 2024
d05805d
corrections in arg type for review
Kacper-W-Kozdon Feb 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions ivy/data_classes/array/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,62 @@ def take(
self, indices, axis=axis, mode=mode, fill_value=fill_value, out=out
)

def unflatten(
self: ivy.Array,
/,
shape: Union[Tuple[int], ivy.Array, ivy.NativeArray],
dim: Optional[int] = 0,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.unflatten. This method
simply wraps the function, and so the docstring for ivy.unflatten also
applies to this method with minimal changes.

Parameters
----------
self
input array
shape
array indices. Must have an integer data type.
dim
axis over which to unflatten. If `axis` is negative,
the function must determine the axis along which to select values
by counting from the last dimension.
By default, the flattened input array is used.
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.

Returns
-------
ret
an array having the same data type as `x`.
The output array must have the same rank
(i.e., number of dimensions) as `x` and
must have the same shape as `x`,
except for the axis specified by `dim`
which is replaced with a tuple specified in `shape`.


Examples
--------
With 'ivy.Array' input:

>>> x = ivy.array([[1.2, 2.3, 3.4, 4.5],
[5.6, 6.7, 7.8, 8.9]])
>>> shape = (2, 2)
>>> y = x.unflatten(shape=shape, dim=dim, out=y)
>>> print(y)
ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]])
"""
return ivy.unflatten(
self._data,
shape=shape,
dim=dim,
out=out,
)

def trim_zeros(
self: ivy.Array,
/,
Expand Down
182 changes: 178 additions & 4 deletions ivy/data_classes/container/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4065,6 +4065,180 @@ def trim_zeros(
"""
return self._static_trim_zeros(self, trim=trim)

@staticmethod
def _static_unflatten(
x: Union[int, ivy.Array, ivy.NativeArray, ivy.Container],
/,
shape: Union[Tuple[int], ivy.Array, ivy.NativeArray, ivy.Container],
dim: Optional[Union[int, ivy.Container]] = 0,
*,
out: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
"""ivy.Container static method variant of ivy.unflatten. This method
simply wraps the function, and so the docstring for ivy.unflatten also
applies to this method with minimal changes.

Parameters
----------
x
input array
shape
array indices. Must have an integer data type.
dim
axis over which to select values. If `axis` is negative,
the function must determine the axis along which to select values
by counting from the last dimension.
By default, the flattened input array is used.
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
key_chains
The key-chains to apply or not apply the method to.
Default is ``None``.
to_apply
If True, the method will be applied to key_chains,
otherwise key_chains will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was
not applied. Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.

Returns
-------
ret
an array having the same data type as `x`.
The output array must have the same rank
(i.e., number of dimensions) as `x` and
must have the same shape as `x`,
except for the axis specified by `axis`
whose size must equal the number of elements in `indices`.


Examples
--------
With 'ivy.Container' input:

>>> x = ivy.Container(a = ivy.array([[True, False, False, True],
[False, True, False, True]])),
... b = ivy.array([[1.2, 2.3, 3.4, 4.5],
[5.6, 6.7, 7.8, 8.9]]),
... c = ivy.array([[1, 2, 3, 4],
[5, 6, 7, 8]]))
>>> dim = 1
>>> shape = (2, 2)
>>> y = ivy.Container._static_unflatten(x, shape=shape, dim=dim)
>>> print(y)
{
a: ivy.array([[[True, False], [False, True]],
[[False, True], [False, True]]])
b: ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]])
c: ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
}
"""
return ContainerBase.cont_multi_map_in_function(
"unflatten",
x,
shape=shape,
dim=dim,
out=out,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

def unflatten(
self: ivy.Container,
/,
shape: Union[Tuple[int], ivy.Array, ivy.NativeArray, ivy.Container],
dim: Optional[Union[int, ivy.Container]] = 0,
*,
out: Optional[Union[ivy.Array, ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.unflatten. This method
simply wraps the function, and so the docstring for ivy.unflatten also
applies to this method with minimal changes.

Parameters
----------
self
input array
shape
array indices. Must have an integer data type.
dim
axis over which to unflatten. If `axis` is negative,
the function must determine the axis along which to select values
by counting from the last dimension.
By default, the flattened input array is used.
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
key_chains
The key-chains to apply or not apply the method to.
Default is ``None``.
to_apply
If True, the method will be applied to key_chains,
otherwise key_chains will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was
not applied. Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.

Returns
-------
ret
an array having the same data type as `x`.
The output array must have the same rank
(i.e., number of dimensions) as `x` and
must have the same shape as `x`,
except for the axis specified by `dim`
which is replaced with a tuple specified in `shape`.


Examples
--------
With 'ivy.Container' input:

>>> x = ivy.Container(a = ivy.array([[True, False, False, True],
[False, True, False, True]])),
... b = ivy.array([[1.2, 2.3, 3.4, 4.5],
[5.6, 6.7, 7.8, 8.9]]),
... c = ivy.array([[1, 2, 3, 4],
[5, 6, 7, 8]]))
>>> dim = 1
>>> shape = (2, 2)
>>> y = x.unflatten(shape=shape, dim=dim)
>>> print(y)
{
a: ivy.array([[[True, False], [False, True]],
[[False, True], [False, True]]])
b: ivy.array([[[1.2, 2.3], [3.4, 4.5]], [[5.6, 6.7], [7.8, 8.9]]])
c: ivy.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
}
"""
return self._static_unflatten(
self,
shape=shape,
dim=dim,
out=out,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)


def concat_from_sequence(
self: ivy.Container,
Expand Down Expand Up @@ -4130,11 +4304,11 @@ def concat_from_sequence(
>>> print(z)
{
'a': ivy.array([[[0, 1],
[3, 2]],
[[2, 3],
[1, 0]]]),
[3, 2]],
[[2, 3],
[1, 0]]]),
'b': ivy.array([[[4, 5],
[1, 0]]])
[1, 0]]])
}
"""
new_input_sequence = (
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/jax/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,8 @@ def trim_zeros(a: JaxArray, /, *, trim: Optional[str] = "bf") -> JaxArray:
def unflatten(
x: JaxArray,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: int = 0,
*,
out: Optional[JaxArray] = None,
order: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/numpy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,8 @@ def put_along_axis(
def unflatten(
x: np.ndarray,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: Optional[int] = 0,
*,
out: Optional[np.ndarray] = None,
order: Optional[str] = None,
Expand Down
16 changes: 15 additions & 1 deletion ivy/functional/backends/paddle/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,12 +908,26 @@ def put_along_axis(
]


@with_supported_dtypes(
{
"2.6.0 and below": (
"int32",
"int64",
"float64",
"complex128",
"float32",
"complex64",
"bool",
)
},
backend_version,
)
@handle_out_argument
def unflatten(
x: paddle.Tensor,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: int = 0,
*,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ def trim_zeros(a: tf.Tensor, /, *, trim: Optional[str] = "bf") -> tf.Tensor:
def unflatten(
x: tf.Tensor,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: Optional[int] = 0,
*,
out: Optional[tf.Tensor] = None,
name: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ def trim_zeros(a: torch.Tensor, /, *, trim: Optional[str] = "bf") -> torch.Tenso
def unflatten(
x: torch.Tensor,
/,
dim: int = 0,
shape: Tuple[int] = None,
dim: Optional[int] = 0,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/ivy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2884,9 +2884,9 @@ def trim_zeros(
def unflatten(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
dim: int,
shape: Tuple[int],
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Expand a dimension of the input tensor over multiple dimensions.
Expand Down Expand Up @@ -2930,4 +2930,4 @@ def unflatten(
>>> ivy.unflatten(torch.randn(5, 12, 3), dim=-2, shape=(2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])
"""
return current_backend(x).unflatten(x, dim=dim, shape=shape, out=out)
return ivy.current_backend(x).unflatten(x, dim=dim, shape=shape, out=out)
Original file line number Diff line number Diff line change
Expand Up @@ -1788,8 +1788,8 @@ def test_torch_triu_indices(
),
get_axis=helpers.get_axis(
shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
max_size=1,
min_size=1,
max_size=0,
min_size=0,
force_int=True,
),
)
Expand All @@ -1804,10 +1804,9 @@ def test_torch_unflatten(
shape,
get_axis,
):
if type(get_axis) is not tuple:
axis = get_axis
else:
axis = 0 if get_axis is None else get_axis[0]
axis = get_axis
if type(axis) is tuple:
axis = 0 if not get_axis else get_axis[0]
dtype, x = dtype_and_values

def factorization(n):
Expand Down Expand Up @@ -1835,7 +1834,8 @@ def get_factor(n):
next = get_factor(n)
factors.append(next)
n //= next

if len(factors) > 1:
factors.remove(1)
return factors

shape_ = (
Expand Down
Loading
Loading