Skip to content

Commit

Permalink
Merge pull request #10288 from YouJiacheng:patch-7
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 442043193
  • Loading branch information
jax authors committed Apr 15, 2022
2 parents 470f58c + 4ff6b1f commit a4b8a44
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ def concatenate(self, key_arrs, axis):
return PRNGKeyArray(self.impl, jnp.concatenate(arrs, axis))

def broadcast_to(self, shape):
if jnp.ndim(shape) == 0:
shape = (shape,)
new_shape = (*shape, *self.impl.key_shape)
return PRNGKeyArray(self.impl, jnp.broadcast_to(self._keys, new_shape))

Expand Down
3 changes: 3 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,9 @@ def test_broadcast_to(self):
ref = jnp.broadcast_to(like(key), (3,))
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (3,))
out = jnp.broadcast_to(key, 3)
self.assertEqual(out.shape, ref.shape)
self.assertEqual(out.shape, (3,))

def test_expand_dims(self):
key = random.PRNGKey(123)
Expand Down

0 comments on commit a4b8a44

Please sign in to comment.