Skip to content

Commit

Permalink
fix: Fix ivy.permute_dims for all backends (#28009)
Browse files Browse the repository at this point in the history
Co-authored-by: NripeshN <86844847+NripeshN@users.noreply.github.com>
  • Loading branch information
vismaysur and NripeshN authored Feb 25, 2024
1 parent 661030a commit 2e3d242
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
3 changes: 3 additions & 0 deletions ivy/functional/backends/jax/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def permute_dims(
copy: Optional[bool] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if copy:
newarr = jnp.copy(x)
return jnp.transpose(newarr, axes)
return jnp.transpose(x, axes)


Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/backends/numpy/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def permute_dims(
copy: Optional[bool] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if copy:
newarr = np.copy(x)
return np.transpose(newarr, axes)
return np.transpose(x, axes)


Expand Down
17 changes: 5 additions & 12 deletions ivy/functional/backends/paddle/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,8 @@ def flip(
return paddle.flip(x, axis)


@with_supported_dtypes(
{
"2.6.0 and below": (
"int32",
"int64",
"float64",
"complex128",
"float32",
"complex64",
"bool",
)
},
@with_unsupported_dtypes(
{"2.6.0 and below": ("uint8", "int8", "int16", "bfloat16", "float16")},
backend_version,
)
def permute_dims(
Expand All @@ -130,6 +120,9 @@ def permute_dims(
copy: Optional[bool] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if copy:
newarr = paddle.clone(x)
return paddle.transpose(newarr, axes)
return paddle.transpose(x, axes)


Expand Down
4 changes: 4 additions & 0 deletions ivy/functional/backends/torch/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def flip(
return torch.flip(x, new_axis)


@with_unsupported_dtypes({"2.1.2 and below": ("bfloat16", "float16")}, backend_version)
def permute_dims(
x: torch.Tensor,
/,
Expand All @@ -86,6 +87,9 @@ def permute_dims(
copy: Optional[bool] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if copy:
newarr = torch.clone(x).detach()
return torch.permute(newarr, axes)
return torch.permute(x, axes)


Expand Down

0 comments on commit 2e3d242

Please sign in to comment.