From 2e3d242d7b89da32a426b9e392413798568a580b Mon Sep 17 00:00:00 2001 From: Vismay Suramwar <83938053+Vismay-dev@users.noreply.github.com> Date: Sun, 25 Feb 2024 04:13:16 -0600 Subject: [PATCH] fix: Fix `ivy.permute_dims` for all backends (#28009) Co-authored-by: NripeshN <86844847+NripeshN@users.noreply.github.com> --- ivy/functional/backends/jax/manipulation.py | 3 +++ ivy/functional/backends/numpy/manipulation.py | 3 +++ ivy/functional/backends/paddle/manipulation.py | 17 +++++------------ ivy/functional/backends/torch/manipulation.py | 4 ++++ 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/ivy/functional/backends/jax/manipulation.py b/ivy/functional/backends/jax/manipulation.py index 360333e8a186c..ec3727de6ccd2 100644 --- a/ivy/functional/backends/jax/manipulation.py +++ b/ivy/functional/backends/jax/manipulation.py @@ -76,6 +76,9 @@ def permute_dims( copy: Optional[bool] = None, out: Optional[JaxArray] = None, ) -> JaxArray: + if copy: + newarr = jnp.copy(x) + return jnp.transpose(newarr, axes) return jnp.transpose(x, axes) diff --git a/ivy/functional/backends/numpy/manipulation.py b/ivy/functional/backends/numpy/manipulation.py index 19023b3a23ab9..ee47d4c9b23e7 100644 --- a/ivy/functional/backends/numpy/manipulation.py +++ b/ivy/functional/backends/numpy/manipulation.py @@ -86,6 +86,9 @@ def permute_dims( copy: Optional[bool] = None, out: Optional[np.ndarray] = None, ) -> np.ndarray: + if copy: + newarr = np.copy(x) + return np.transpose(newarr, axes) return np.transpose(x, axes) diff --git a/ivy/functional/backends/paddle/manipulation.py b/ivy/functional/backends/paddle/manipulation.py index aa1a01315f952..c88903542619a 100644 --- a/ivy/functional/backends/paddle/manipulation.py +++ b/ivy/functional/backends/paddle/manipulation.py @@ -108,18 +108,8 @@ def flip( return paddle.flip(x, axis) -@with_supported_dtypes( - { - "2.6.0 and below": ( - "int32", - "int64", - "float64", - "complex128", - "float32", - "complex64", - "bool", - ) - }, +@with_unsupported_dtypes( + {"2.6.0 and below": ("uint8", "int8", "int16", "bfloat16", "float16")}, backend_version, ) def permute_dims( @@ -130,6 +120,9 @@ def permute_dims( copy: Optional[bool] = None, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + if copy: + newarr = paddle.clone(x) + return paddle.transpose(newarr, axes) return paddle.transpose(x, axes) diff --git a/ivy/functional/backends/torch/manipulation.py b/ivy/functional/backends/torch/manipulation.py index 023fb00a7409c..0560c89f7e64f 100644 --- a/ivy/functional/backends/torch/manipulation.py +++ b/ivy/functional/backends/torch/manipulation.py @@ -78,6 +78,7 @@ def flip( return torch.flip(x, new_axis) +@with_unsupported_dtypes({"2.1.2 and below": ("bfloat16", "float16")}, backend_version) def permute_dims( x: torch.Tensor, /, @@ -86,6 +87,9 @@ def permute_dims( copy: Optional[bool] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if copy: + newarr = torch.clone(x).detach() + return torch.permute(newarr, axes) return torch.permute(x, axes)