diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index 51492cabfd..469b9b4d67 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -111,7 +111,16 @@ def apply( shifts: ak.index.Index | None, outlength: ShapeItem, ) -> ak.contents.NumpyArray: - raise RuntimeError("Cannot differentiate through count_zero") + assert isinstance(array, ak.contents.NumpyArray) + result = jax.numpy.ones_like(array.data, dtype=array.dtype) + result = jax.ops.segment_sum(result, parents.data) + + if np.issubdtype(array.dtype, np.complexfloating): + return ak.contents.NumpyArray( + result.view(array.dtype), backend=array.backend + ) + else: + return ak.contents.NumpyArray(result, backend=array.backend) @overloads(_reducers.CountNonzero) @@ -169,7 +178,7 @@ def apply( return ak.contents.NumpyArray( array.backend.nplike.asarray(result, dtype=array.dtype) ) - elif array.dtype.type in (np.complex128, np.complex64): + elif np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray(result.view(array.dtype)) else: return ak.contents.NumpyArray(result, backend=array.backend) @@ -200,7 +209,7 @@ def apply( jax.ops.segment_sum(jax.numpy.log(array.data), parents.data) ) - if array.dtype.type in (np.complex128, np.complex64): + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( result.view(array.dtype), backend=array.backend ) @@ -318,7 +327,7 @@ def apply( result = jax.ops.segment_min(array.data, parents.data) result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype)) - if array.dtype.type in (np.complex128, np.complex64): + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( array.backend.nplike.asarray( result.view(array.dtype), dtype=array.dtype @@ -379,7 +388,7 @@ def apply( result = jax.ops.segment_max(array.data, parents.data) result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype)) - if array.dtype.type in (np.complex128, np.complex64): + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( array.backend.nplike.asarray( result.view(array.dtype), dtype=array.dtype diff --git a/tests/test_2638_mean_and_count_grads.py b/tests/test_2638_mean_and_count_grads.py new file mode 100644 index 0000000000..12c7934e86 --- /dev/null +++ b/tests/test_2638_mean_and_count_grads.py @@ -0,0 +1,32 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import pytest + +import awkward as ak + +jax = pytest.importorskip("jax") + + +def test(): + ak.jax.register_and_check() + + array = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]], backend="jax") + + val_mean, grad_mean = jax.value_and_grad(ak.mean, argnums=0)(array) + _, grad_sum = jax.value_and_grad(ak.sum, argnums=0)(array) + val_count, grad_count = jax.value_and_grad(ak.count, argnums=0)(array) + + assert val_mean == 3 + assert ak.all( + grad_mean == ak.Array([[0.2, 0.2, 0.2], [], [0.2, 0.2]], backend="jax") + ) + + # mean is treated as scaled sum + assert ak.all(grad_mean == grad_sum / val_count) + + assert val_count == 5 + assert ak.all( + grad_count == ak.Array([[0.0, 0.0, 0.0], [], [0.0, 0.0]], backend="jax") + )