Skip to content

Commit

Permalink
feat: add Numpy take function and its test (#27462)
Browse files Browse the repository at this point in the history
  • Loading branch information
alvin-98 authored Dec 9, 2023
1 parent cf75d73 commit 9898297
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,9 @@ def tril_indices(n, k=0, m=None):
def unravel_index(indices, shape, order="C"):
ret = [x.astype("int64") for x in ivy.unravel_index(indices, shape)]
return tuple(ret)


@to_ivy_arrays_and_back
@handle_numpy_out
def take(a, indices, /, *, axis=None, out=None, mode="raise"):
return ivy.take(a, indices, axis=axis, out=out, mode=mode)
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,42 @@ def test_numpy_take_along_axis(
indices=indices,
axis=axis,
)


@handle_frontend_test(
fn_tree="numpy.take",
dtype_x_indices_axis=helpers.array_indices_axis(
array_dtypes=helpers.get_dtypes("valid"),
indices_dtypes=["int32", "int64"],
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=5,
indices_same_dims=True,
valid_bounds=False,
),
mode=st.sampled_from(["clip", "wrap"]),
)
def test_numpy_take(
*,
dtype_x_indices_axis,
mode,
test_flags,
frontend,
backend_fw,
fn_tree,
on_device,
):
dtypes, x, indices, axis, _ = dtype_x_indices_axis
helpers.test_frontend_function(
input_dtypes=dtypes,
backend_to_test=backend_fw,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
a=x,
indices=indices,
axis=axis,
mode=mode,
)

0 comments on commit 9898297

Please sign in to comment.