diff --git a/ivy/functional/backends/jax/sorting.py b/ivy/functional/backends/jax/sorting.py index acee3f979316d..fc54befcf72ce 100644 --- a/ivy/functional/backends/jax/sorting.py +++ b/ivy/functional/backends/jax/sorting.py @@ -20,7 +20,7 @@ def argsort( ) -> JaxArray: kind = "stable" if stable else "quicksort" return ( - jnp.argsort(-x, axis=axis, kind=kind) + jnp.argsort(x, axis=axis, kind=kind, descending=descending) if descending else jnp.argsort(x, axis=axis, kind=kind) )