diff --git a/src/awkward/_reducers.py b/src/awkward/_reducers.py index 5c413840d7..61aa6a7f32 100644 --- a/src/awkward/_reducers.py +++ b/src/awkward/_reducers.py @@ -129,6 +129,14 @@ class ArgMin(KernelReducer): preferred_dtype: Final = np.int64 needs_position: Final = True + @classmethod + def _dtype_for_kernel(cls, dtype: DTypeLike) -> DTypeLike: + dtype = np.dtype(dtype) + if dtype == np.bool_: + return np.dtype(np.int8) + else: + return super()._dtype_for_kernel(dtype) + def apply( self, array: ak.contents.NumpyArray, @@ -183,6 +191,14 @@ class ArgMax(KernelReducer): preferred_dtype: Final = np.int64 needs_position: Final = True + @classmethod + def _dtype_for_kernel(cls, dtype: DTypeLike) -> DTypeLike: + dtype = np.dtype(dtype) + if dtype == np.bool_: + return np.dtype(np.int8) + else: + return super()._dtype_for_kernel(dtype) + def apply( self, array: ak.contents.NumpyArray, diff --git a/tests/test_2934_argmin_argmax_bool.py b/tests/test_2934_argmin_argmax_bool.py new file mode 100644 index 0000000000..e8dd3dab0c --- /dev/null +++ b/tests/test_2934_argmin_argmax_bool.py @@ -0,0 +1,12 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import awkward as ak + + +def test(): + array = ak.Array([[True, False, True], [False, False, True, False], [True, False]]) + + assert ak.argmin(array, axis=1).tolist() == [1, 0, 1] + assert ak.argmax(array, axis=1).tolist() == [0, 2, 0]