Skip to content

Commit

Permalink
fix: added the tensorflow backend implementation for bernoulli and fi…
Browse files Browse the repository at this point in the history
…xed the backend implementations for the other backends regarding default shape and default dtype (#28139)
  • Loading branch information
vedpatwardhan authored Jan 31, 2024
1 parent 91a0b62 commit 37a010d
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 16 deletions.
3 changes: 2 additions & 1 deletion ivy/functional/backends/jax/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def bernoulli(
seed: Optional[int] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
dtype = dtype if dtype is not None else probs.dtype
if seed:
rng_input = jax.random.PRNGKey(seed)
else:
Expand All @@ -126,4 +127,4 @@ def bernoulli(
probs = jax.nn.softmax(logits, axis=-1)
if hasattr(probs, "shape") and not _check_shapes_broadcastable(shape, probs.shape):
shape = probs.shape
return jax.random.bernoulli(rng_input, probs, shape=shape)
return jax.random.bernoulli(rng_input, probs, shape=shape).astype(dtype)
1 change: 1 addition & 0 deletions ivy/functional/backends/numpy/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def bernoulli(
seed: Optional[int] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
dtype = dtype if dtype is not None else probs.dtype
if seed is not None:
np.random.seed(seed)
if logits is not None:
Expand Down
5 changes: 4 additions & 1 deletion ivy/functional/backends/paddle/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,17 @@ def bernoulli(
seed: Optional[int] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
dtype = dtype if dtype is not None else probs.dtype
if seed is not None:
paddle.seed(seed)
if probs is not None:
probs = probs
elif logits is not None:
probs = ivy.softmax(logits)
probs = paddle.cast(probs, dtype)
probs = paddle.unsqueeze(probs, 0) if len(probs.shape) == 0 else probs
squeeze = len(probs.shape) == 0
probs = paddle.unsqueeze(probs, 0) if squeeze else probs
probs = paddle.maximum(probs, paddle.full_like(probs, 1e-6))
sample = paddle.bernoulli(probs)
sample = paddle.squeeze(sample, 0) if squeeze else sample
return sample
11 changes: 8 additions & 3 deletions ivy/functional/backends/tensorflow/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,20 @@ def poisson(
return ret


@with_unsupported_dtypes({"2.15.0 and below": ("bfloat16",)}, backend_version)
def bernoulli(
probs: Union[float, tf.Tensor, tf.Variable],
*,
logits: Union[float, tf.Tensor, tf.Variable] = None,
shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
device: Optional[str] = None,
dtype: DType,
dtype: Optional[str] = None,
seed: Optional[int] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
pass
# TODO: Implement purely in tensorflow
dtype = dtype if dtype is not None else probs.dtype
if logits is not None:
probs = tf.nn.softmax(logits, -1)
if not _check_shapes_broadcastable(shape, probs.shape):
shape = probs.shape
return tf.keras.backend.random_bernoulli(shape, probs, dtype, seed)
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def bernoulli(
seed: Optional[int] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
dtype = dtype if dtype is not None else probs.dtype
if seed is not None:
tf.random.set_seed(seed)
if logits is not None:
Expand Down
18 changes: 8 additions & 10 deletions ivy/functional/backends/torch/experimental/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,14 @@ def bernoulli(
seed: Optional[int] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
dtype = dtype if dtype is not None else probs.dtype
if seed:
torch.manual_seed(seed)
if logits is not None:
if not _check_shapes_broadcastable(shape, logits.shape):
shape = logits.shape
elif probs is not None:
if not _check_shapes_broadcastable(shape, probs.shape):
shape = probs.shape
return (
torch.distributions.bernoulli.Bernoulli(probs=probs, logits=logits)
.sample(shape)
.to(device, dtype)
)
probs = torch.nn.functional.softmax(logits, -1)
if not _check_shapes_broadcastable(shape, probs.shape):
shape = probs.shape
return torch.bernoulli(probs, out=out).to(device, dtype).broadcast_to(shape)


bernoulli.support_native_out = True
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
),
seed=helpers.ints(min_value=0, max_value=100),
test_gradients=st.just(False),
ground_truth_backend="torch",
)
def test_bernoulli(
*, dtype_and_probs, seed, test_flags, backend_fw, fn_name, on_device
Expand All @@ -25,18 +26,20 @@ def test_bernoulli(
assume(
not ("torch" in str(backend_fw) and "float16" in dtype and on_device == "cpu")
)
helpers.test_function(
ret_np_flat_from_target, ret_np_from_gt_flat = helpers.test_function(
input_dtypes=dtype,
test_flags=test_flags,
on_device=on_device,
backend_to_test=backend_fw,
fn_name=fn_name,
test_values=False,
return_flat_np_arrays=True,
probs=probs[0],
logits=None,
shape=None,
seed=seed,
)
helpers.assert_same_type_and_shape([ret_np_flat_from_target, ret_np_from_gt_flat])


# beta
Expand Down

0 comments on commit 37a010d

Please sign in to comment.