From 37a010d29d30354b06f7571738cd233f00773377 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Wed, 31 Jan 2024 18:46:14 +0530 Subject: [PATCH] fix: added the tensorflow backend implementation for bernoulli and fixed the backend implementations for the other backends regarding default shape and default dtype (#28139) --- .../backends/jax/experimental/random.py | 3 ++- .../backends/numpy/experimental/random.py | 1 + .../backends/paddle/experimental/random.py | 5 ++++- .../backends/tensorflow/experimental/random.py | 11 ++++++++--- .../tf_probability/experimental/random.py | 1 + .../backends/torch/experimental/random.py | 18 ++++++++---------- .../test_experimental/test_core/test_random.py | 5 ++++- 7 files changed, 28 insertions(+), 16 deletions(-) diff --git a/ivy/functional/backends/jax/experimental/random.py b/ivy/functional/backends/jax/experimental/random.py index 17b2d54037a99..19093eda961ea 100644 --- a/ivy/functional/backends/jax/experimental/random.py +++ b/ivy/functional/backends/jax/experimental/random.py @@ -117,6 +117,7 @@ def bernoulli( seed: Optional[int] = None, out: Optional[JaxArray] = None, ) -> JaxArray: + dtype = dtype if dtype is not None else probs.dtype if seed: rng_input = jax.random.PRNGKey(seed) else: @@ -126,4 +127,4 @@ def bernoulli( probs = jax.nn.softmax(logits, axis=-1) if hasattr(probs, "shape") and not _check_shapes_broadcastable(shape, probs.shape): shape = probs.shape - return jax.random.bernoulli(rng_input, probs, shape=shape) + return jax.random.bernoulli(rng_input, probs, shape=shape).astype(dtype) diff --git a/ivy/functional/backends/numpy/experimental/random.py b/ivy/functional/backends/numpy/experimental/random.py index 16293d4eac407..34960431a3210 100644 --- a/ivy/functional/backends/numpy/experimental/random.py +++ b/ivy/functional/backends/numpy/experimental/random.py @@ -99,6 +99,7 @@ def bernoulli( seed: Optional[int] = None, out: Optional[np.ndarray] = None, ) -> np.ndarray: + dtype = dtype if dtype is not None else probs.dtype if seed is not None: np.random.seed(seed) if logits is not None: diff --git a/ivy/functional/backends/paddle/experimental/random.py b/ivy/functional/backends/paddle/experimental/random.py index 462084349e2ae..e5095b76bc297 100644 --- a/ivy/functional/backends/paddle/experimental/random.py +++ b/ivy/functional/backends/paddle/experimental/random.py @@ -127,6 +127,7 @@ def bernoulli( seed: Optional[int] = None, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: + dtype = dtype if dtype is not None else probs.dtype if seed is not None: paddle.seed(seed) if probs is not None: @@ -134,7 +135,9 @@ def bernoulli( elif logits is not None: probs = ivy.softmax(logits) probs = paddle.cast(probs, dtype) - probs = paddle.unsqueeze(probs, 0) if len(probs.shape) == 0 else probs + squeeze = len(probs.shape) == 0 + probs = paddle.unsqueeze(probs, 0) if squeeze else probs probs = paddle.maximum(probs, paddle.full_like(probs, 1e-6)) sample = paddle.bernoulli(probs) + sample = paddle.squeeze(sample, 0) if squeeze else sample return sample diff --git a/ivy/functional/backends/tensorflow/experimental/random.py b/ivy/functional/backends/tensorflow/experimental/random.py index c2742785656e3..2a1bff81f7029 100644 --- a/ivy/functional/backends/tensorflow/experimental/random.py +++ b/ivy/functional/backends/tensorflow/experimental/random.py @@ -90,15 +90,20 @@ def poisson( return ret +@with_unsupported_dtypes({"2.15.0 and below": ("bfloat16",)}, backend_version) def bernoulli( probs: Union[float, tf.Tensor, tf.Variable], *, logits: Union[float, tf.Tensor, tf.Variable] = None, shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None, device: Optional[str] = None, - dtype: DType, + dtype: Optional[str] = None, seed: Optional[int] = None, out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: - pass - # TODO: Implement purely in tensorflow + dtype = dtype if dtype is not None else probs.dtype + if logits is not None: + probs = tf.nn.softmax(logits, -1) + if not _check_shapes_broadcastable(shape, probs.shape): + shape = probs.shape + return tf.keras.backend.random_bernoulli(shape, probs, dtype, seed) diff --git a/ivy/functional/backends/tensorflow/sub_backends/tf_probability/experimental/random.py b/ivy/functional/backends/tensorflow/sub_backends/tf_probability/experimental/random.py index ef672977bb4d2..b33cd8304df18 100644 --- a/ivy/functional/backends/tensorflow/sub_backends/tf_probability/experimental/random.py +++ b/ivy/functional/backends/tensorflow/sub_backends/tf_probability/experimental/random.py @@ -63,6 +63,7 @@ def bernoulli( seed: Optional[int] = None, out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: + dtype = dtype if dtype is not None else probs.dtype if seed is not None: tf.random.set_seed(seed) if logits is not None: diff --git a/ivy/functional/backends/torch/experimental/random.py b/ivy/functional/backends/torch/experimental/random.py index a4d5b116b0782..5532f5914a42f 100644 --- a/ivy/functional/backends/torch/experimental/random.py +++ b/ivy/functional/backends/torch/experimental/random.py @@ -116,16 +116,14 @@ def bernoulli( seed: Optional[int] = None, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: + dtype = dtype if dtype is not None else probs.dtype if seed: torch.manual_seed(seed) if logits is not None: - if not _check_shapes_broadcastable(shape, logits.shape): - shape = logits.shape - elif probs is not None: - if not _check_shapes_broadcastable(shape, probs.shape): - shape = probs.shape - return ( - torch.distributions.bernoulli.Bernoulli(probs=probs, logits=logits) - .sample(shape) - .to(device, dtype) - ) + probs = torch.nn.functional.softmax(logits, -1) + if not _check_shapes_broadcastable(shape, probs.shape): + shape = probs.shape + return torch.bernoulli(probs, out=out).to(device, dtype).broadcast_to(shape) + + +bernoulli.support_native_out = True diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py index 730c1749af982..0da1bbb87a106 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_random.py @@ -16,6 +16,7 @@ ), seed=helpers.ints(min_value=0, max_value=100), test_gradients=st.just(False), + ground_truth_backend="torch", ) def test_bernoulli( *, dtype_and_probs, seed, test_flags, backend_fw, fn_name, on_device @@ -25,18 +26,20 @@ def test_bernoulli( assume( not ("torch" in str(backend_fw) and "float16" in dtype and on_device == "cpu") ) - helpers.test_function( + ret_np_flat_from_target, ret_np_from_gt_flat = helpers.test_function( input_dtypes=dtype, test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, test_values=False, + return_flat_np_arrays=True, probs=probs[0], logits=None, shape=None, seed=seed, ) + helpers.assert_same_type_and_shape([ret_np_flat_from_target, ret_np_from_gt_flat]) # beta