Skip to content

Commit

Permalink
Better checks everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Feb 13, 2024
1 parent 688f286 commit bcfc919
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/awkward/_connect/jax/reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def apply(
return ak.contents.NumpyArray(
array.backend.nplike.asarray(result, dtype=array.dtype)
)
elif array.dtype.type in (np.complex128, np.complex64):
elif np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(result.view(array.dtype))
else:
return ak.contents.NumpyArray(result, backend=array.backend)
Expand Down Expand Up @@ -209,7 +209,7 @@ def apply(
jax.ops.segment_sum(jax.numpy.log(array.data), parents.data)
)

if array.dtype.type in (np.complex128, np.complex64):
if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
result.view(array.dtype), backend=array.backend
)
Expand Down Expand Up @@ -327,7 +327,7 @@ def apply(
result = jax.ops.segment_min(array.data, parents.data)
result = jax.numpy.minimum(result, self._min_initial(self.initial, array.dtype))

if array.dtype.type in (np.complex128, np.complex64):
if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
array.backend.nplike.asarray(
result.view(array.dtype), dtype=array.dtype
Expand Down Expand Up @@ -388,7 +388,7 @@ def apply(
result = jax.ops.segment_max(array.data, parents.data)

result = jax.numpy.maximum(result, self._max_initial(self.initial, array.dtype))
if array.dtype.type in (np.complex128, np.complex64):
if np.issubdtype(array.dtype, np.complexfloating):
return ak.contents.NumpyArray(
array.backend.nplike.asarray(
result.view(array.dtype), dtype=array.dtype
Expand Down

0 comments on commit bcfc919

Please sign in to comment.