Skip to content

Commit

Permalink
fix: fixed ivy.cumsum tests at all backend (#27974)
Browse files Browse the repository at this point in the history
Co-authored-by: NripeshN <nripesh14@gmail.com>
  • Loading branch information
samthakur587 and NripeshN committed Jan 22, 2024
1 parent 5ec27ac commit f2316e7
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions ivy/functional/backends/jax/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def cumprod(
return jnp.flip(x, axis=axis)


@with_unsupported_dtypes({"0.4.23 and below": "bool"}, backend_version)
def cumsum(
x: JaxArray,
axis: int = 0,
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/backends/tensorflow/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def cumprod(
return tf.math.cumprod(x, axis, exclusive, reverse)


@with_unsupported_dtypes({"2.15.0 and below": "bool"}, backend_version)
def cumsum(
x: Union[tf.Tensor, tf.Variable],
axis: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/backends/torch/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ def cumprod(
# TODO: bfloat16 support is added in PyTorch 1.12.1
@with_unsupported_dtypes(
{
"1.12.1 and below": ("uint8", "float16", "bfloat16"),
"1.12.1 and above": ("uint8", "float16"),
"1.12.1 and below": ("uint8", "bool", "float16", "bfloat16"),
"1.12.1 and above": ("uint8", "bool", "float16"),
},
backend_version,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_cumprod(
dtype_x_axis_castable=_get_castable_dtype(),
exclusive=st.booleans(),
reverse=st.booleans(),
test_gradients=st.just(False),
)
def test_cumsum(
*,
Expand Down

0 comments on commit f2316e7

Please sign in to comment.