Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ops.random.shuffle #907

Merged
merged 6 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions keras_core/backend/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,8 @@
return jax.lax.select(
mask, inputs / keep_prob, jax.numpy.zeros_like(inputs)
)


def shuffle(x, axis=0, seed=None):
seed = jax_draw_seed(seed)
return jax.random.shuffle(seed, x, axis)

Check warning on line 86 in keras_core/backend/jax/random.py

View check run for this annotation

Codecov / codecov/patch

keras_core/backend/jax/random.py#L85-L86

Added lines #L85 - L86 were not covered by tests
6 changes: 6 additions & 0 deletions keras_core/backend/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,9 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
mask = rng.uniform(size=noise_shape) < keep_prob
mask = np.broadcast_to(mask, inputs.shape)
return np.where(mask, inputs / keep_prob, np.zeros_like(inputs))


def shuffle(x, axis=0, seed=None):
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
return rng.permuted(x, axis=axis)
11 changes: 11 additions & 0 deletions keras_core/backend/tensorflow/random.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tensorflow as tf
from tensorflow.experimental import numpy as tfnp

from keras_core.backend.common import standardize_dtype
from keras_core.backend.config import floatx
Expand Down Expand Up @@ -83,3 +84,13 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
noise_shape=noise_shape,
seed=seed,
)


def shuffle(x, axis=0, seed=None):
seed = tf_draw_seed(seed)
if axis == 0:
return tf.random.experimental.stateless_shuffle(x, seed=seed)
x = tfnp.swapaxes(x, axis1=0, axis2=axis)
x = tf.random.experimental.stateless_shuffle(x, seed=seed)
x = tfnp.swapaxes(x, axis1=0, axis2=axis)
return x
25 changes: 25 additions & 0 deletions keras_core/backend/torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,28 @@
return torch.nn.functional.dropout(
inputs, p=rate, training=True, inplace=False
)


def shuffle(x, axis=0, seed=None):
# Ref: https://github.com/pytorch/pytorch/issues/71409
x = convert_to_tensor(x)

# Get permutation indices
# Do not use generator during symbolic execution.
if get_device() == "meta":
row_perm = torch.rand(x.shape[: axis + 1], device=get_device()).argsort(

Check warning on line 172 in keras_core/backend/torch/random.py

View check run for this annotation

Codecov / codecov/patch

keras_core/backend/torch/random.py#L172

Added line #L172 was not covered by tests
axis
)
else:
generator = torch_seed_generator(seed)
row_perm = torch.rand(
x.shape[: axis + 1], generator=generator, device=get_device()
).argsort(axis)
for _ in range(x.ndim - axis - 1):
row_perm.unsqueeze_(-1)

# Reformat this for the gather operation
row_perm = row_perm.repeat(
*[1 for _ in range(axis + 1)], *(x.shape[axis + 1 :])
)
return x.gather(axis, row_perm)
20 changes: 20 additions & 0 deletions keras_core/random/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,23 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
return backend.random.dropout(
inputs, rate, noise_shape=noise_shape, seed=seed
)


@keras_core_export("keras_core.random.shuffle")
def shuffle(x, axis=0, seed=None):
"""Shuffle the elements of a tensor uniformly at random along an axis.

Args:
x: The tensor to be shuffled.
axis: An integer specifying the axis along which to shuffle. Defaults to
`0`.
seed: A Python integer or instance of
`keras_core.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras_core.random.SeedGenerator`.
"""
return backend.random.shuffle(x, axis=axis, seed=seed)
17 changes: 17 additions & 0 deletions keras_core/random/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,20 @@ def random_numbers(seed):
self.assertGreater(np.abs(y2 - y3), 1e-4)

seed_generator.global_seed_generator().state.assign(seed)

def test_shuffle(self):
x = np.arange(100).reshape(10, 10)

# Test axis=0
y = random.shuffle(x, seed=0)

self.assertFalse(np.all(x == ops.convert_to_numpy(y)))
self.assertAllClose(np.sum(x, axis=0), ops.sum(y, axis=0))
self.assertNotAllClose(np.sum(x, axis=1), ops.sum(y, axis=1))

# Test axis=1
y = random.shuffle(x, axis=1, seed=0)

self.assertFalse(np.all(x == ops.convert_to_numpy(y)))
self.assertAllClose(np.sum(x, axis=1), ops.sum(y, axis=1))
self.assertNotAllClose(np.sum(x, axis=0), ops.sum(y, axis=0))
Loading