diff --git a/src/awkward/_connect/jax/reducers.py b/src/awkward/_connect/jax/reducers.py index b6f30718a3..469b9b4d67 100644 --- a/src/awkward/_connect/jax/reducers.py +++ b/src/awkward/_connect/jax/reducers.py @@ -178,7 +178,7 @@ def apply( return ak.contents.NumpyArray( array.backend.nplike.asarray(result, dtype=array.dtype) ) - elif array.dtype.type in (np.complex128, np.complex64): + elif np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray(result.view(array.dtype)) else: return ak.contents.NumpyArray(result, backend=array.backend) @@ -209,7 +209,7 @@ def apply( jax.ops.segment_sum(jax.numpy.log(array.data), parents.data) ) - if array.dtype.type in (np.complex128, np.complex64): + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( result.view(array.dtype), backend=array.backend ) @@ -327,7 +327,7 @@ def apply( result = jax.ops.segment_min(array.data, parents.data) result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype)) - if array.dtype.type in (np.complex128, np.complex64): + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( array.backend.nplike.asarray( result.view(array.dtype), dtype=array.dtype @@ -388,7 +388,7 @@ def apply( result = jax.ops.segment_max(array.data, parents.data) result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype)) - if array.dtype.type in (np.complex128, np.complex64): + if np.issubdtype(array.dtype, np.complexfloating): return ak.contents.NumpyArray( array.backend.nplike.asarray( result.view(array.dtype), dtype=array.dtype