From 63c7c0276bdde0efb2fb75b150b52887c3f087ad Mon Sep 17 00:00:00 2001 From: samthakur587 Date: Sun, 21 Jan 2024 01:52:22 +0530 Subject: [PATCH] fix: fixed ivy.cumsum tests at all backend --- ivy/functional/backends/jax/statistical.py | 1 + ivy/functional/backends/tensorflow/statistical.py | 1 + ivy/functional/backends/torch/statistical.py | 4 ++-- .../test_ivy/test_functional/test_core/test_statistical.py | 1 + 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/jax/statistical.py b/ivy/functional/backends/jax/statistical.py index a0b31da7eca77..bff5fca9a2d1e 100644 --- a/ivy/functional/backends/jax/statistical.py +++ b/ivy/functional/backends/jax/statistical.py @@ -183,6 +183,7 @@ def cumprod( return jnp.flip(x, axis=axis) +@with_unsupported_dtypes({"0.4.23 and below": "bool"}, backend_version) def cumsum( x: JaxArray, axis: int = 0, diff --git a/ivy/functional/backends/tensorflow/statistical.py b/ivy/functional/backends/tensorflow/statistical.py index f79a665fb2e99..f111a2804e5a1 100644 --- a/ivy/functional/backends/tensorflow/statistical.py +++ b/ivy/functional/backends/tensorflow/statistical.py @@ -200,6 +200,7 @@ def cumprod( return tf.math.cumprod(x, axis, exclusive, reverse) +@with_unsupported_dtypes({"2.15.0 and below": "bool"}, backend_version) def cumsum( x: Union[tf.Tensor, tf.Variable], axis: int = 0, diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py index 7ec801d192d5d..73c2b03fde600 100644 --- a/ivy/functional/backends/torch/statistical.py +++ b/ivy/functional/backends/torch/statistical.py @@ -290,8 +290,8 @@ def cumprod( # TODO: bfloat16 support is added in PyTorch 1.12.1 @with_unsupported_dtypes( { - "1.12.1 and below": ("uint8", "float16", "bfloat16"), - "1.12.1 and above": ("uint8", "float16"), + "1.12.1 and below": ("uint8", "bool", "float16", "bfloat16"), + "1.12.1 and above": ("uint8", "bool", "float16"), }, backend_version, ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py b/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py index 47d8ef1d894dd..143db6235ea77 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_statistical.py @@ -153,6 +153,7 @@ def test_cumprod( dtype_x_axis_castable=_get_castable_dtype(), exclusive=st.booleans(), reverse=st.booleans(), + test_gradients=st.just(False), ) def test_cumsum( *,