Skip to content

Commit

Permalink
Implement JAX dispatch for Argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 24, 2024
1 parent 3dd1f80 commit 26ba673
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
12 changes: 11 additions & 1 deletion pytensor/link/jax/dispatch/sort.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax import numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.sort import SortOp
from pytensor.tensor.sort import ArgSortOp, SortOp


@jax_funcify.register(SortOp)
Expand All @@ -12,3 +12,13 @@ def sort(arr, axis):
return jnp.sort(arr, axis=axis, stable=stable)

return sort


@jax_funcify.register(ArgSortOp)
def jax_funcify_ArgSort(op, **kwargs):
stable = op.kind == "stable"

def argsort(arr, axis):
return jnp.argsort(arr, axis=axis, stable=stable)

return argsort
7 changes: 4 additions & 3 deletions tests/link/jax/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import sort
from pytensor.tensor.sort import argsort, sort
from tests.link.jax.test_basic import compare_jax_and_py


@pytest.mark.parametrize("axis", [None, -1])
def test_sort(axis):
@pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = sort(x, axis=axis)
out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])

0 comments on commit 26ba673

Please sign in to comment.