feat: added unflatten frontend and backend support (#27416)
Co-authored-by: joaozenobio <zenobiojoao@gmail.com>
Kacper-W-Kozdon and joaozenobio authored Jan 22, 2024
1 parent e62f5b6 commit 12c965a
Showing 8 changed files with 229 additions and 3 deletions.
17 changes: 17 additions & 0 deletions ivy/functional/backends/jax/experimental/manipulation.py
@@ -15,6 +15,7 @@
import jax.lax as jlax
from numbers import Number
from collections import namedtuple
from ivy.func_wrapper import handle_out_argument

# local
import ivy
@@ -468,3 +469,19 @@ def take(

def trim_zeros(a: JaxArray, /, *, trim: Optional[str] = "bf") -> JaxArray:
    return jnp.trim_zeros(a, trim=trim)


@handle_out_argument
def unflatten(
    x: JaxArray,
    /,
    dim: int = 0,
    shape: Tuple[int] = None,
    *,
    out: Optional[JaxArray] = None,
    order: Optional[str] = None,
) -> JaxArray:
    dim = abs(len(x.shape) + dim) if dim < 0 else dim
    res_shape = x.shape[:dim] + shape + x.shape[dim + 1 :]
    res = jnp.reshape(x, res_shape)
    return res
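JAX has no native unflatten, so the backend above splices the target sub-shape into `x.shape` and reshapes; `jnp.reshape` then infers a single `-1` entry. A minimal standalone sketch of the same shape arithmetic (illustrative values, not part of the diff):

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4, 1))
dim, shape = 1, (2, 2)
# Splice the new sub-shape in place of the flattened axis, as above.
res_shape = x.shape[:dim] + shape + x.shape[dim + 1:]
print(jnp.reshape(x, res_shape).shape)  # (3, 2, 2, 1)
```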
18 changes: 17 additions & 1 deletion ivy/functional/backends/numpy/experimental/manipulation.py
@@ -18,7 +18,7 @@
# local
import ivy
from ivy.functional.backends.numpy.helpers import _scalar_output_to_0d_array
from ivy.func_wrapper import with_supported_dtypes
from ivy.func_wrapper import with_supported_dtypes, handle_out_argument

# noinspection PyProtectedMember
from . import backend_version
@@ -601,3 +601,19 @@ def put_along_axis(
put_along_axis.partial_mixed_handler = lambda *args, mode=None, **kwargs: mode in [
    "replace",
]


@handle_out_argument
def unflatten(
    x: np.ndarray,
    /,
    dim: int = 0,
    shape: Tuple[int] = None,
    *,
    out: Optional[np.ndarray] = None,
    order: Optional[str] = None,
) -> np.ndarray:
    dim = abs(len(x.shape) + dim) if dim < 0 else dim
    res_shape = x.shape[:dim] + shape + x.shape[dim + 1 :]
    res = np.reshape(x, res_shape)
    return res
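The NumPy version is identical apart from the array type; the one subtlety both share is the normalization of negative `dim` values before slicing the shape tuple. A quick sketch of that step (illustrative, not part of the diff):

```python
import numpy as np

x = np.zeros((5, 12, 3))
dim = -2
# Normalize a negative axis exactly as the backend does: -2 -> 1 here.
dim = abs(len(x.shape) + dim) if dim < 0 else dim
res_shape = x.shape[:dim] + (3, 4) + x.shape[dim + 1:]
print(np.reshape(x, res_shape).shape)  # (5, 3, 4, 3)
```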
14 changes: 14 additions & 0 deletions ivy/functional/backends/paddle/experimental/manipulation.py
@@ -20,6 +20,7 @@
    with_unsupported_device_and_dtypes,
    with_supported_dtypes,
    with_unsupported_dtypes,
    handle_out_argument,
)
import paddle
import ivy
@@ -905,3 +906,16 @@ def put_along_axis(
    "sum",
    "mul",
]


@handle_out_argument
def unflatten(
    x: paddle.Tensor,
    /,
    dim: int = 0,
    shape: Tuple[int] = None,
    *,
    out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
    res = paddle.unflatten(x, dim, shape)
    return res
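Paddle ships a native `paddle.unflatten`, so this backend is a direct passthrough. A usage sketch, assuming a Paddle build recent enough to provide the op (it landed around Paddle 2.6):

```python
import paddle

x = paddle.zeros([3, 4, 1])
# Split axis 1 (size 4) into (2, 2) using the native op.
y = paddle.unflatten(x, 1, [2, 2])
print(y.shape)  # [3, 2, 2, 1]
```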
18 changes: 17 additions & 1 deletion ivy/functional/backends/tensorflow/experimental/manipulation.py
@@ -16,7 +16,7 @@
import tensorflow as tf

# local
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.func_wrapper import with_unsupported_dtypes, handle_out_argument
from .. import backend_version
import ivy
from ivy.functional.ivy.experimental.manipulation import _to_tf_padding
@@ -565,3 +565,19 @@ def trim_zeros(a: tf.Tensor, /, *, trim: Optional[str] = "bf") -> tf.Tensor:
    last = tf.minimum(last, tf.cast(tf.shape(a)[0], tf.int64))

    return a[first:last]


@handle_out_argument
def unflatten(
    x: tf.Tensor,
    /,
    dim: int = 0,
    shape: Tuple[int] = None,
    *,
    out: Optional[tf.Tensor] = None,
    name: Optional[str] = None,
) -> tf.Tensor:
    dim = abs(len(x.shape) + dim) if dim < 0 else dim
    res_shape = x.shape[:dim] + tf.TensorShape(shape) + x.shape[dim + 1 :]
    res = tf.reshape(x, res_shape, name)
    return res
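TensorFlow likewise has no dedicated unflatten, so the backend concatenates `tf.TensorShape` slices and reshapes. A minimal sketch of that shape arithmetic, assuming `TensorShape` concatenation with `+` and a fully defined result shape (which the merged code relies on):

```python
import tensorflow as tf

x = tf.zeros((3, 4, 1))
# TensorShape slices concatenate with +, as in the backend above.
res_shape = x.shape[:1] + tf.TensorShape((2, 2)) + x.shape[2:]
print(tf.reshape(x, res_shape).shape)  # (3, 2, 2, 1)
```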
19 changes: 18 additions & 1 deletion ivy/functional/backends/torch/experimental/manipulation.py
@@ -17,7 +17,11 @@


# local
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from ivy.func_wrapper import (
    with_unsupported_dtypes,
    with_supported_dtypes,
    handle_out_argument,
)
from .. import backend_version
import ivy
from ivy.functional.ivy.experimental.manipulation import (
@@ -639,3 +643,16 @@ def trim_zeros(a: torch.Tensor, /, *, trim: Optional[str] = "bf") -> torch.Tensor:
    else:
        last = last - 1
    return a[first:last]


@handle_out_argument
def unflatten(
    x: torch.Tensor,
    /,
    dim: int = 0,
    shape: Tuple[int] = None,
    *,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    res = torch.unflatten(x, dim, shape)
    return res
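PyTorch provides `torch.unflatten` natively, including `-1` inference for a single unknown factor, so the backend simply delegates. For example:

```python
import torch

x = torch.randn(3, 4, 1)
print(torch.unflatten(x, 1, (2, 2)).shape)   # torch.Size([3, 2, 2, 1])
print(torch.unflatten(x, 1, (-1, 2)).shape)  # torch.Size([3, 2, 2, 1])
```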
5 changes: 5 additions & 0 deletions ivy/functional/frontends/torch/miscellaneous_ops.py
@@ -544,6 +544,11 @@ def triu_indices(row, col, offset=0, dtype="int64", device="cpu", layout=None):
    return ivy.stack(ivy.nonzero(sample_matrix)).astype(dtype)


@to_ivy_arrays_and_back
def unflatten(input, /, *, dim, sizes):
    return ivy.unflatten(input, dim=dim, shape=sizes, out=None)
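The frontend mirrors `torch.unflatten(input, dim, sizes)` (though here `dim` and `sizes` are keyword-only) and routes to `ivy.unflatten`. A usage sketch, assuming a NumPy backend is available:

```python
import ivy
from ivy.functional.frontends import torch as torch_frontend

ivy.set_backend("numpy")
x = ivy.zeros((3, 4, 1))
y = torch_frontend.unflatten(x, dim=1, sizes=(2, 2))
print(tuple(y.shape))  # (3, 2, 2, 1)
```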


@to_ivy_arrays_and_back
def vander(x, N=None, increasing=False):
    # if N == 0:
61 changes: 61 additions & 0 deletions ivy/functional/ivy/experimental/manipulation.py
@@ -2870,3 +2870,64 @@ def trim_zeros(
    ),
    "to_skip": ("inputs_to_ivy_arrays",),
}


@handle_exceptions
@handle_backend_invalid
@handle_nestable
@handle_array_like_without_promotion
@handle_view
@handle_out_argument
@to_native_arrays_and_back
@handle_array_function
@handle_device
def unflatten(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    dim: int,
    shape: Tuple[int],
    out: Optional[ivy.Array] = None,
) -> ivy.Array:
    """Expand a dimension of the input tensor over multiple dimensions.

    Parameters
    ----------
    x
        input tensor.
    dim
        dimension to be unflattened, specified as an index into ``x.shape``.
    shape
        new shape of the unflattened dimension. One of its elements can be -1,
        in which case the corresponding output dimension is inferred.
        Otherwise, the product of its elements must equal ``x.shape[dim]``.
    out
        optional output array, for writing the result to. It must have a shape
        that the inputs broadcast to.

    Returns
    -------
    ret
        view of the input with the specified dimension unflattened.

    Both the description and the type hints above assume an array input for
    simplicity, but this function is *nestable*, and therefore also accepts
    :class:`ivy.Container` instances in place of any of the arguments.

    Examples
    --------
    >>> ivy.unflatten(torch.randn(3, 4, 1), dim=1, shape=(2, 2)).shape
    torch.Size([3, 2, 2, 1])
    >>> ivy.unflatten(torch.randn(3, 4, 1), dim=1, shape=(-1, 2)).shape
    torch.Size([3, 2, 2, 1])
    >>> ivy.unflatten(torch.randn(5, 12, 3), dim=-2, shape=(2, 2, 3, 1, 1)).shape
    torch.Size([5, 2, 2, 3, 1, 1, 3])
    """
    return current_backend(x).unflatten(x, dim=dim, shape=shape, out=out)
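Since the docstring examples pass raw torch tensors, here is an equivalent backend-agnostic sketch using ivy arrays (illustrative; assumes the NumPy backend is installed):

```python
import ivy

ivy.set_backend("numpy")
x = ivy.zeros((5, 12, 3))
# dim=-2 is normalized by the backend to the size-12 axis.
y = ivy.unflatten(x, dim=-2, shape=(3, 4))
print(tuple(y.shape))  # (5, 3, 4, 3)
```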
@@ -1777,6 +1777,86 @@ def test_torch_triu_indices(
    )


# unflatten
@handle_frontend_test(
    fn_tree="torch.unflatten",
    shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
    dtype_and_values=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("valid"),
        min_num_dims=1,
        shape_key="shape",
    ),
    get_axis=helpers.get_axis(
        shape=st.shared(helpers.get_shape(min_num_dims=1), key="shape"),
        max_size=1,
        min_size=1,
        force_int=True,
    ),
)
def test_torch_unflatten(
    *,
    dtype_and_values,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    backend_fw,
    shape,
    get_axis,
):
    if type(get_axis) is not tuple:
        axis = get_axis
    else:
        axis = 0 if get_axis is None else get_axis[0]
    dtype, x = dtype_and_values

    def factorization(n):
        factors = [1]

        def get_factor(n):
            x_fixed = 2
            cycle_size = 2
            x = 2
            factor = 1 if n % 2 else 2

            while factor == 1:
                for count in range(cycle_size):
                    if factor > 1:
                        break
                    x = (x * x + 1) % n
                    factor = math.gcd(x - x_fixed, n)

                cycle_size *= 2
                x_fixed = x

            return factor

        while n > 1:
            next = get_factor(n)
            factors.append(next)
            n //= next

        return factors

    shape_ = (
        tuple(factorization(shape[axis]))
        if tuple(factorization(shape[axis]))
        else shape
    )
    helpers.test_frontend_function(
        input_dtypes=dtype,
        frontend=frontend,
        backend_to_test=backend_fw,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        test_values=False,
        input=x[0],
        dim=axis,
        sizes=shape_,
    )
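The inner `factorization` helper is Brent's cycle-finding variant of Pollard's rho: it splits `shape[axis]` into factors whose product is the original dimension size, which is exactly the invariant `unflatten` needs from `sizes`. (Note the `if ... else shape` guard is always truthy, since the factor list is seeded with 1.) A simpler trial-division sketch of the same invariant (hypothetical helper, not part of the test file):

```python
import math

def trial_division(n):
    # Far simpler than Pollard's rho, but gives the same guarantee:
    # the returned factors multiply back to n.
    factors = [1]
    d = 2
    while d * d <= n:
        while n % d == 0:
            factors.append(d)
            n //= d
        d += 1
    if n > 1:
        factors.append(n)
    return factors

print(trial_division(12))             # [1, 2, 2, 3]
print(math.prod(trial_division(12)))  # 12
```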


# vander
@handle_frontend_test(
    fn_tree="torch.vander",
