Skip to content

Commit

Permalink
Merge branch 'unifyai:main' into index_add#26392
Browse files Browse the repository at this point in the history
  • Loading branch information
imsoumya18 authored Oct 7, 2023
2 parents 1d131ac + 373b033 commit 74bc04b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
10 changes: 10 additions & 0 deletions ivy/functional/frontends/paddle/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
)


@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64")},
"paddle",
)
@to_ivy_arrays_and_back
def multinomial(x, num_samples=1, replacement=False, name=None):
n = num_samples + 1
return ivy.multinomial(n, num_samples, probs=x, replace=replacement)


@with_supported_dtypes(
{"2.5.1 and below": ("float32", "float64")},
"paddle",
Expand Down
58 changes: 58 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,64 @@
from ivy_tests.test_ivy.helpers import handle_frontend_test


# --- Helpers --- #
# --------------- #


@st.composite
def _multinomial_helper(draw):
input_dtype_and_x = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
shape=helpers.get_shape(min_num_dims=1, max_num_dims=2, min_dim_size=2),
)
)
num_samples = draw(st.integers(min_value=1, max_value=10))
if num_samples > 2:
replacement = True
else:
replacement = draw(st.booleans())

input_dtype, x = input_dtype_and_x

total = sum(x)
x = [arr / total for arr in x]

return input_dtype, x, num_samples, replacement


# --- Main --- #
# ------------ #


# multinomial
@handle_frontend_test(
fn_tree="paddle.tensor.random.multinomial",
input_dtype_and_x=_multinomial_helper(),
)
def test_paddle_multinomial(
input_dtype_and_x,
test_flags,
frontend,
backend_fw,
fn_tree,
on_device,
):
input_dtype, x, num_samples, replacement = input_dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
test_flags=test_flags,
frontend=frontend,
backend_to_test=backend_fw,
fn_tree=fn_tree,
on_device=on_device,
test_values=False,
x=x[0],
num_samples=num_samples,
replacement=replacement,
)


@handle_frontend_test(
fn_tree="paddle.normal",
input_dtypes=st.sampled_from([["float32"], ["float64"]]),
Expand Down

0 comments on commit 74bc04b

Please sign in to comment.