Skip to content

Commit

Permalink
Fix keras.ops.quantile implementation for floating point inputs that …
Browse files Browse the repository at this point in the history
…are not tf.float32. (keras-team#20438)
  • Loading branch information
mrry authored and wang-xianghao committed Nov 20, 2024
1 parent 983b377 commit 233574f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,8 @@ def _get_indices(method):
nan_batch_members = tf.reshape(
nan_batch_members, shape=right_rank_matched_shape
)
gathered_y = tf.where(nan_batch_members, float("NaN"), gathered_y)
nan_value = tf.constant(float("NaN"), dtype=x.dtype)
gathered_y = tf.where(nan_batch_members, nan_value, gathered_y)

# Expand dimensions if requested
if keepdims:
Expand Down

0 comments on commit 233574f

Please sign in to comment.