Skip to content

Commit

Permalink
Improve implementation of TF shuffle and make it XLA compilable
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Dec 11, 2024
1 parent 1c70a85 commit 95b4383
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions keras/src/backend/tensorflow/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,10 @@ def dropout(inputs, rate, noise_shape=None, seed=None):


def shuffle(x, axis=0, seed=None):
from keras.src.backend.tensorflow.numpy import swapaxes

seed = _cast_seed(draw_seed(seed))
if axis == 0:
return tf.random.experimental.stateless_shuffle(x, seed=seed)
x = swapaxes(x, axis1=0, axis2=axis)
x = tf.random.experimental.stateless_shuffle(x, seed=seed)
x = swapaxes(x, axis1=0, axis2=axis)
return x
indices = tf.argsort(
tf.random.stateless_uniform(shape=[tf.shape(x)[axis]], seed=seed)
)
return tf.gather(x, indices, axis=axis)


def gamma(shape, alpha, dtype=None, seed=None):
Expand Down

0 comments on commit 95b4383

Please sign in to comment.