Skip to content

Commit

Permalink
Initial attempt at NumPy trapz() function - Issue #17939 (#18093)
Browse files Browse the repository at this point in the history
  • Loading branch information
Harrison-O authored Jul 20, 2023
1 parent db869cb commit abff08e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,8 @@ def ediff1d(ary, to_end=None, to_begin=None):
to_end = ivy.array(to_end)
diffs = ivy.concat((diffs, to_end))
return diffs


@to_ivy_arrays_and_back
def trapz(y, x=None, dx=1.0, axis=-1):
return ivy.trapz(y, x=x, dx=dx, axis=axis)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# global
import numpy as np
from hypothesis import strategies as st, assume

# local
Expand Down Expand Up @@ -365,3 +366,76 @@ def test_numpy_ediff1d(
to_end=to_end,
to_begin=to_begin,
)


# trapz
@st.composite
def _either_x_dx(draw):
rand = (draw(st.integers(min_value=0, max_value=1)),)
if rand == 0:
either_x_dx = draw(
helpers.dtype_and_values(
avaliable_dtypes=st.shared(
helpers.get_dtypes("float"), key="trapz_dtype"
),
min_value=-100,
max_value=100,
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
)
)
return rand, either_x_dx
else:
either_x_dx = draw(
st.floats(min_value=-10, max_value=10),
)
return rand, either_x_dx


@handle_frontend_test(
fn_tree="numpy.trapz",
dtype_values_axis=helpers.dtype_values_axis(
available_dtypes=st.shared(helpers.get_dtypes("float"), key="trapz_dtype"),
min_value=-100,
max_value=100,
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
allow_neg_axes=True,
valid_axis=True,
force_int_axis=True,
),
rand_either=_either_x_dx(),
)
def test_numpy_trapz(
dtype_values_axis,
rand_either,
fn_tree,
frontend,
test_flags,
on_device,
):
input_dtype, y, axis = dtype_values_axis
rand, either_x_dx = rand_either
if rand == 0:
dtype_x, x = either_x_dx
x = np.asarray(x, dtype=dtype_x)
dx = None
else:
x = None
dx = either_x_dx
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
fn_tree=fn_tree,
test_flags=test_flags,
on_device=on_device,
y=np.asarray(y[0], dtype=input_dtype[0]),
x=x,
dx=dx,
axis=axis,
)

0 comments on commit abff08e

Please sign in to comment.