Skip to content

Commit

Permalink
Better check for complex types
Browse files Browse the repository at this point in the history
Co-authored-by: Angus Hollands <goosey15@gmail.com>
  • Loading branch information
Saransh-cpp and agoose77 authored Feb 13, 2024
1 parent d9eb323 commit bbb0fcb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def apply(
result = jax.numpy.ones_like(array.data, dtype=array.dtype)
result = jax.ops.segment_sum(result, parents.data)

if array.dtype.type in (np.complex128, np.complex64):
if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
result.view(array.dtype), backend=array.backend
)
Expand Down

0 comments on commit bbb0fcb

Please sign in to comment.